feat: support group_size for search_group_by(#33544) (#33720)

related: #33544

mainly changes in three aspects:

1. enable setting group_size for group by function
2. separate normal reduce and group by reduce
3. eliminate unnecessary padding in search results before reducing

Signed-off-by: MrPresent-Han <chun.han@gmail.com>
Co-authored-by: MrPresent-Han <chun.han@gmail.com>
This commit is contained in:
Chun Han 2024-07-12 10:17:36 +08:00 committed by GitHub
parent 5bb0d21e32
commit f00c529aea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 708 additions and 385 deletions

View File

@ -25,8 +25,9 @@
namespace milvus { namespace milvus {
struct SearchInfo { struct SearchInfo {
int64_t topk_; int64_t topk_{0};
int64_t round_decimal_; int64_t group_size_{1};
int64_t round_decimal_{0};
FieldId field_id_; FieldId field_id_;
MetricType metric_type_; MetricType metric_type_;
knowhere::Json search_params_; knowhere::Json search_params_;

View File

@ -195,6 +195,7 @@ struct SearchResult {
std::vector<float> distances_; std::vector<float> distances_;
std::vector<int64_t> seg_offsets_; std::vector<int64_t> seg_offsets_;
std::optional<std::vector<GroupByValueType>> group_by_values_; std::optional<std::vector<GroupByValueType>> group_by_values_;
std::optional<int64_t> group_size_;
// first fill data during fillPrimaryKey, and then update data after reducing search results // first fill data during fillPrimaryKey, and then update data after reducing search results
std::vector<PkType> primary_keys_; std::vector<PkType> primary_keys_;
@ -209,7 +210,7 @@ struct SearchResult {
std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data_; std::map<FieldId, std::unique_ptr<milvus::DataArray>> output_fields_data_;
// used for reduce, filter invalid pk, get real topks count // used for reduce, filter invalid pk, get real topks count
std::vector<size_t> topk_per_nq_prefix_sum_; std::vector<size_t> topk_per_nq_prefix_sum_{};
//Vector iterators, used for group by //Vector iterators, used for group by
std::optional<std::vector<std::shared_ptr<VectorIterator>>> std::optional<std::vector<std::shared_ptr<VectorIterator>>>

View File

@ -26,7 +26,7 @@ set(MILVUS_QUERY_SRCS
SearchOnIndex.cpp SearchOnIndex.cpp
SearchBruteForce.cpp SearchBruteForce.cpp
SubSearchResult.cpp SubSearchResult.cpp
GroupByOperator.cpp groupby/SearchGroupByOperator.cpp
PlanProto.cpp PlanProto.cpp
) )
add_library(milvus_query ${MILVUS_QUERY_SRCS}) add_library(milvus_query ${MILVUS_QUERY_SRCS})

View File

@ -90,6 +90,12 @@ CreateSearchPlanByExpr(const Schema& schema,
return ProtoParser(schema).CreatePlan(plan_node); return ProtoParser(schema).CreatePlan(plan_node);
} }
// Builds a search Plan for `schema` from an in-memory proto::plan::PlanNode
// by handing the node to ProtoParser::CreatePlan.
std::unique_ptr<Plan>
CreateSearchPlanFromPlanNode(const Schema& schema,
                             const proto::plan::PlanNode& plan_node) {
    ProtoParser parser(schema);
    return parser.CreatePlan(plan_node);
}
std::unique_ptr<RetrievePlan> std::unique_ptr<RetrievePlan>
CreateRetrievePlanByExpr(const Schema& schema, CreateRetrievePlanByExpr(const Schema& schema,
const void* serialized_expr_plan, const void* serialized_expr_plan,

View File

@ -32,6 +32,10 @@ CreateSearchPlanByExpr(const Schema& schema,
const void* serialized_expr_plan, const void* serialized_expr_plan,
const int64_t size); const int64_t size);
// Creates a search Plan for `schema` from a proto::plan::PlanNode object
// (as opposed to CreateSearchPlanByExpr, which takes a serialized buffer and
// its size).
std::unique_ptr<Plan>
CreateSearchPlanFromPlanNode(const Schema& schema,
const proto::plan::PlanNode& plan_node);
std::unique_ptr<PlaceholderGroup> std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan, ParsePlaceholderGroup(const Plan* plan,
const uint8_t* blob, const uint8_t* blob,

View File

@ -209,7 +209,11 @@ ProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) {
if (query_info_proto.group_by_field_id() > 0) { if (query_info_proto.group_by_field_id() > 0) {
auto group_by_field_id = FieldId(query_info_proto.group_by_field_id()); auto group_by_field_id = FieldId(query_info_proto.group_by_field_id());
search_info.group_by_field_id_ = group_by_field_id; search_info.group_by_field_id_ = group_by_field_id;
search_info.group_size_ = query_info_proto.group_size() > 0
? query_info_proto.group_size()
: 1;
} }
auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> { auto plan_node = [&]() -> std::unique_ptr<VectorPlanNode> {
if (anns_proto.vector_type() == if (anns_proto.vector_type() ==
milvus::proto::plan::VectorType::BinaryVector) { milvus::proto::plan::VectorType::BinaryVector) {

View File

@ -10,7 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License // or implied. See the License for the specific language governing permissions and limitations under the License
#include "SearchOnIndex.h" #include "SearchOnIndex.h"
#include "query/GroupByOperator.h" #include "query/groupby/SearchGroupByOperator.h"
namespace milvus::query { namespace milvus::query {
void void

View File

@ -17,7 +17,7 @@
#include "query/SearchBruteForce.h" #include "query/SearchBruteForce.h"
#include "query/SearchOnSealed.h" #include "query/SearchOnSealed.h"
#include "query/helper.h" #include "query/helper.h"
#include "query/GroupByOperator.h" #include "query/groupby/SearchGroupByOperator.h"
namespace milvus::query { namespace milvus::query {

View File

@ -71,4 +71,11 @@ out_of_range(int64_t t) {
return gt_ub<T>(t) || lt_lb<T>(t); return gt_ub<T>(t) || lt_lb<T>(t);
} }
// Returns true when dis1 ranks strictly ahead of dis2 under metric_type:
// the larger value wins when PositivelyRelated(metric_type) holds, otherwise
// the smaller value wins. Equal distances are never reported as closer.
inline bool
dis_closer(float dis1, float dis2, const MetricType& metric_type) {
    return PositivelyRelated(metric_type) ? dis1 > dis2 : dis1 < dis2;
}
} // namespace milvus::query } // namespace milvus::query

View File

@ -13,35 +13,43 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "GroupByOperator.h" #include "SearchGroupByOperator.h"
#include "common/Consts.h" #include "common/Consts.h"
#include "segcore/SegmentSealedImpl.h" #include "segcore/SegmentSealedImpl.h"
#include "Utils.h" #include "query/Utils.h"
namespace milvus { namespace milvus {
namespace query { namespace query {
void void
GroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators, SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
const SearchInfo& search_info, const SearchInfo& search_info,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
const segcore::SegmentInternalInterface& segment, const segcore::SegmentInternalInterface& segment,
std::vector<int64_t>& seg_offsets, std::vector<int64_t>& seg_offsets,
std::vector<float>& distances) { std::vector<float>& distances,
std::vector<size_t>& topk_per_nq_prefix_sum) {
//1. get search meta //1. get search meta
FieldId group_by_field_id = search_info.group_by_field_id_.value(); FieldId group_by_field_id = search_info.group_by_field_id_.value();
auto data_type = segment.GetFieldDataType(group_by_field_id); auto data_type = segment.GetFieldDataType(group_by_field_id);
int max_total_size =
search_info.topk_ * search_info.group_size_ * iterators.size();
seg_offsets.reserve(max_total_size);
distances.reserve(max_total_size);
group_by_values.reserve(max_total_size);
topk_per_nq_prefix_sum.reserve(iterators.size() + 1);
switch (data_type) { switch (data_type) {
case DataType::INT8: { case DataType::INT8: {
auto dataGetter = GetDataGetter<int8_t>(segment, group_by_field_id); auto dataGetter = GetDataGetter<int8_t>(segment, group_by_field_id);
GroupIteratorsByType<int8_t>(iterators, GroupIteratorsByType<int8_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
distances, distances,
search_info.metric_type_); search_info.metric_type_,
topk_per_nq_prefix_sum);
break; break;
} }
case DataType::INT16: { case DataType::INT16: {
@ -49,11 +57,13 @@ GroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GetDataGetter<int16_t>(segment, group_by_field_id); GetDataGetter<int16_t>(segment, group_by_field_id);
GroupIteratorsByType<int16_t>(iterators, GroupIteratorsByType<int16_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
distances, distances,
search_info.metric_type_); search_info.metric_type_,
topk_per_nq_prefix_sum);
break; break;
} }
case DataType::INT32: { case DataType::INT32: {
@ -61,11 +71,13 @@ GroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GetDataGetter<int32_t>(segment, group_by_field_id); GetDataGetter<int32_t>(segment, group_by_field_id);
GroupIteratorsByType<int32_t>(iterators, GroupIteratorsByType<int32_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
distances, distances,
search_info.metric_type_); search_info.metric_type_,
topk_per_nq_prefix_sum);
break; break;
} }
case DataType::INT64: { case DataType::INT64: {
@ -73,22 +85,26 @@ GroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GetDataGetter<int64_t>(segment, group_by_field_id); GetDataGetter<int64_t>(segment, group_by_field_id);
GroupIteratorsByType<int64_t>(iterators, GroupIteratorsByType<int64_t>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
distances, distances,
search_info.metric_type_); search_info.metric_type_,
topk_per_nq_prefix_sum);
break; break;
} }
case DataType::BOOL: { case DataType::BOOL: {
auto dataGetter = GetDataGetter<bool>(segment, group_by_field_id); auto dataGetter = GetDataGetter<bool>(segment, group_by_field_id);
GroupIteratorsByType<bool>(iterators, GroupIteratorsByType<bool>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
distances, distances,
search_info.metric_type_); search_info.metric_type_,
topk_per_nq_prefix_sum);
break; break;
} }
case DataType::VARCHAR: { case DataType::VARCHAR: {
@ -96,11 +112,13 @@ GroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
GetDataGetter<std::string>(segment, group_by_field_id); GetDataGetter<std::string>(segment, group_by_field_id);
GroupIteratorsByType<std::string>(iterators, GroupIteratorsByType<std::string>(iterators,
search_info.topk_, search_info.topk_,
search_info.group_size_,
*dataGetter, *dataGetter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
distances, distances,
search_info.metric_type_); search_info.metric_type_,
topk_per_nq_prefix_sum);
break; break;
} }
default: { default: {
@ -117,19 +135,24 @@ void
GroupIteratorsByType( GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators, const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK, int64_t topK,
int64_t group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets, std::vector<int64_t>& seg_offsets,
std::vector<float>& distances, std::vector<float>& distances,
const knowhere::MetricType& metrics_type) { const knowhere::MetricType& metrics_type,
std::vector<size_t>& topk_per_nq_prefix_sum) {
topk_per_nq_prefix_sum.push_back(0);
for (auto& iterator : iterators) { for (auto& iterator : iterators) {
GroupIteratorResult<T>(iterator, GroupIteratorResult<T>(iterator,
topK, topK,
group_size,
data_getter, data_getter,
group_by_values, group_by_values,
seg_offsets, seg_offsets,
distances, distances,
metrics_type); metrics_type);
topk_per_nq_prefix_sum.push_back(seg_offsets.size());
} }
} }
@ -137,23 +160,20 @@ template <typename T>
void void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator, GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK, int64_t topK,
int64_t group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets, std::vector<int64_t>& offsets,
std::vector<float>& distances, std::vector<float>& distances,
const knowhere::MetricType& metrics_type) { const knowhere::MetricType& metrics_type) {
//1. //1.
std::unordered_map<T, std::pair<int64_t, float>> groupMap; GroupByMap<T> groupMap(topK, group_size);
//2. do iteration until fill the whole map or run out of all data //2. do iteration until fill the whole map or run out of all data
//note it may enumerate all data inside a segment and can block following //note it may enumerate all data inside a segment and can block following
//query and search possibly //query and search possibly
auto dis_closer = [&](float l, float r) { std::vector<std::tuple<int64_t, float, T>> res;
if (PositivelyRelated(metrics_type)) while (iterator->HasNext() && !groupMap.IsGroupResEnough()) {
return l > r;
return l < r;
};
while (iterator->HasNext() && groupMap.size() < topK) {
auto offset_dis_pair = iterator->Next(); auto offset_dis_pair = iterator->Next();
AssertInfo( AssertInfo(
offset_dis_pair.has_value(), offset_dis_pair.has_value(),
@ -162,38 +182,22 @@ GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
auto offset = offset_dis_pair.value().first; auto offset = offset_dis_pair.value().first;
auto dis = offset_dis_pair.value().second; auto dis = offset_dis_pair.value().second;
T row_data = data_getter.Get(offset); T row_data = data_getter.Get(offset);
auto it = groupMap.find(row_data); if (groupMap.Push(row_data)) {
if (it == groupMap.end()) { res.emplace_back(offset, dis, row_data);
groupMap.emplace(row_data, std::make_pair(offset, dis));
} else if (dis_closer(dis, it->second.second)) {
it->second = {offset, dis};
} }
} }
//3. sorted based on distances and metrics //3. sorted based on distances and metrics
std::vector<std::pair<T, std::pair<int64_t, float>>> sortedGroupVals(
groupMap.begin(), groupMap.end());
auto customComparator = [&](const auto& lhs, const auto& rhs) { auto customComparator = [&](const auto& lhs, const auto& rhs) {
return dis_closer(lhs.second.second, rhs.second.second); return dis_closer(std::get<1>(lhs), std::get<1>(rhs), metrics_type);
}; };
std::sort(sortedGroupVals.begin(), sortedGroupVals.end(), customComparator); std::sort(res.begin(), res.end(), customComparator);
//4. save groupBy results //4. save groupBy results
group_by_values.reserve(sortedGroupVals.size()); for (auto iter = res.cbegin(); iter != res.cend(); iter++) {
offsets.reserve(sortedGroupVals.size()); offsets.push_back(std::get<0>(*iter));
distances.reserve(sortedGroupVals.size()); distances.push_back(std::get<1>(*iter));
for (auto iter = sortedGroupVals.cbegin(); iter != sortedGroupVals.cend(); group_by_values.emplace_back(std::move(std::get<2>(*iter)));
iter++) {
group_by_values.emplace_back(iter->first);
offsets.push_back(iter->second.first);
distances.push_back(iter->second.second);
}
//5. padding topK results, extra memory consumed will be removed when reducing
for (std::size_t idx = groupMap.size(); idx < topK; idx++) {
offsets.push_back(INVALID_SEG_OFFSET);
distances.push_back(0.0);
group_by_values.emplace_back(std::monostate{});
} }
} }

View File

@ -23,6 +23,7 @@
#include "segcore/SegmentSealedImpl.h" #include "segcore/SegmentSealedImpl.h"
#include "segcore/ConcurrentVector.h" #include "segcore/ConcurrentVector.h"
#include "common/Span.h" #include "common/Span.h"
#include "query/Utils.h"
namespace milvus { namespace milvus {
namespace query { namespace query {
@ -167,28 +168,67 @@ PrepareVectorIteratorsFromIndex(const SearchInfo& search_info,
} }
void void
GroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators, SearchGroupBy(const std::vector<std::shared_ptr<VectorIterator>>& iterators,
const SearchInfo& searchInfo, const SearchInfo& searchInfo,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
const segcore::SegmentInternalInterface& segment, const segcore::SegmentInternalInterface& segment,
std::vector<int64_t>& seg_offsets, std::vector<int64_t>& seg_offsets,
std::vector<float>& distances); std::vector<float>& distances,
std::vector<size_t>& topk_per_nq_prefix_sum);
template <typename T> template <typename T>
void void
GroupIteratorsByType( GroupIteratorsByType(
const std::vector<std::shared_ptr<VectorIterator>>& iterators, const std::vector<std::shared_ptr<VectorIterator>>& iterators,
int64_t topK, int64_t topK,
int64_t group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& seg_offsets, std::vector<int64_t>& seg_offsets,
std::vector<float>& distances, std::vector<float>& distances,
const knowhere::MetricType& metrics_type); const knowhere::MetricType& metrics_type,
std::vector<size_t>& topk_per_nq_prefix_sum);
// Bookkeeping for group-by search over a single iterator: tracks how many
// hits have been accepted for each group value, admitting at most
// `group_capacity` distinct groups and at most `group_size` hits per group.
template <typename T>
struct GroupByMap {
 private:
    // accepted-hit count per group value
    std::unordered_map<T, int> group_map_{};
    // maximum number of distinct groups
    int group_capacity_{0};
    // maximum number of hits accepted per group
    int group_size_{0};
    // number of groups that have already reached group_size_
    int enough_group_count{0};

 public:
    GroupByMap(int group_capacity, int group_size)
        : group_capacity_(group_capacity), group_size_(group_size) {
    }

    // True once group_capacity_ distinct groups exist and every one of them
    // holds group_size_ hits, i.e. no further input can be accepted.
    bool
    IsGroupResEnough() const {
        return group_map_.size() == static_cast<size_t>(group_capacity_) &&
               enough_group_count == group_capacity_;
    }

    // Offers one hit with group value `t`.
    // Returns true when the hit is accepted (and counted), false when it is
    // rejected because the group table is full or the group is already full.
    bool
    Push(const T& t) {
        // Look the key up without operator[]: operator[] would insert a zero
        // entry for a rejected key, pushing size() past group_capacity_ and
        // making IsGroupResEnough() permanently false.
        auto it = group_map_.find(t);
        if (it == group_map_.end()) {
            if (group_map_.size() >= static_cast<size_t>(group_capacity_)) {
                return false;
            }
            it = group_map_.emplace(t, 0).first;
        }
        if (it->second >= group_size_) {
            //we ignore following input no matter the distance as knowhere::iterator doesn't guarantee
            //strictly increase/decreasing distance output
            //but this should not be a very serious influence to overall recall rate
            return false;
        }
        it->second += 1;
        if (it->second >= group_size_) {
            enough_group_count += 1;
        }
        return true;
    }
};
template <typename T> template <typename T>
void void
GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator, GroupIteratorResult(const std::shared_ptr<VectorIterator>& iterator,
int64_t topK, int64_t topK,
int64_t group_size,
const DataGetter<T>& data_getter, const DataGetter<T>& data_getter,
std::vector<GroupByValueType>& group_by_values, std::vector<GroupByValueType>& group_by_values,
std::vector<int64_t>& offsets, std::vector<int64_t>& offsets,

View File

@ -25,7 +25,7 @@
#include "plan/PlanNode.h" #include "plan/PlanNode.h"
#include "exec/Task.h" #include "exec/Task.h"
#include "segcore/SegmentInterface.h" #include "segcore/SegmentInterface.h"
#include "query/GroupByOperator.h" #include "query/groupby/SearchGroupByOperator.h"
namespace milvus::query { namespace milvus::query {
namespace impl { namespace impl {
@ -193,14 +193,20 @@ ExecPlanNodeVisitor::VectorVisitorImpl(VectorPlanNode& node) {
search_result); search_result);
search_result.total_data_cnt_ = final_view.size(); search_result.total_data_cnt_ = final_view.size();
if (search_result.vector_iterators_.has_value()) { if (search_result.vector_iterators_.has_value()) {
AssertInfo(search_result.vector_iterators_.value().size() ==
search_result.total_nq_,
"Vector Iterators' count must be equal to total_nq_, Check "
"your code");
std::vector<GroupByValueType> group_by_values; std::vector<GroupByValueType> group_by_values;
GroupBy(search_result.vector_iterators_.value(), SearchGroupBy(search_result.vector_iterators_.value(),
node.search_info_, node.search_info_,
group_by_values, group_by_values,
*segment, *segment,
search_result.seg_offsets_, search_result.seg_offsets_,
search_result.distances_); search_result.distances_,
search_result.topk_per_nq_prefix_sum_);
search_result.group_by_values_ = std::move(group_by_values); search_result.group_by_values_ = std::move(group_by_values);
search_result.group_size_ = node.search_info_.group_size_;
AssertInfo(search_result.seg_offsets_.size() == AssertInfo(search_result.seg_offsets_.size() ==
search_result.group_by_values_.value().size(), search_result.group_by_values_.value().size(),
"Wrong state! search_result group_by_values_ size:{} is not " "Wrong state! search_result group_by_values_ size:{} is not "

View File

@ -24,8 +24,6 @@ set(SEGCORE_FILES
SegmentGrowingImpl.cpp SegmentGrowingImpl.cpp
SegmentSealedImpl.cpp SegmentSealedImpl.cpp
FieldIndexing.cpp FieldIndexing.cpp
Reduce.cpp
StreamReduce.cpp
metrics_c.cpp metrics_c.cpp
plan_c.cpp plan_c.cpp
reduce_c.cpp reduce_c.cpp
@ -39,7 +37,10 @@ set(SEGCORE_FILES
Utils.cpp Utils.cpp
ConcurrentVector.cpp ConcurrentVector.cpp
ReduceUtils.cpp ReduceUtils.cpp
check_vec_index_c.cpp) check_vec_index_c.cpp
reduce/Reduce.cpp
reduce/StreamReduce.cpp
reduce/GroupReduce.cpp)
add_library(milvus_segcore SHARED ${SEGCORE_FILES}) add_library(milvus_segcore SHARED ${SEGCORE_FILES})
target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage milvus_futures) target_link_libraries(milvus_segcore milvus_query milvus_bitset milvus_exec ${OpenMP_CXX_FLAGS} milvus-storage milvus_futures)

View File

@ -0,0 +1,193 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#include "GroupReduce.h"
#include "log/Log.h"
#include "segcore/SegmentInterface.h"
#include "segcore/ReduceUtils.h"
namespace milvus::segcore {
// Collects the group-by value of every reduced hit for queries
// [nq_begin, nq_end) into one array of result_count entries, placing each
// value at the position recorded in its segment's result_offsets_, and then
// passes that array to AssembleGroupByValues.
void
GroupReduceHelper::FillOtherData(
    int result_count,
    int64_t nq_begin,
    int64_t nq_end,
    std::unique_ptr<milvus::proto::schema::SearchResultData>& search_res_data) {
    std::vector<GroupByValueType> group_by_values(result_count);
    for (int64_t qi = nq_begin; qi < nq_end; ++qi) {
        for (auto search_result : search_results_) {
            AssertInfo(search_result != nullptr,
                       "null search result when reorganize");
            // A segment with no recorded result offsets has nothing to copy.
            if (search_result->result_offsets_.empty()) {
                continue;
            }
            // Rows of this segment that belong to query qi.
            const auto row_end =
                search_result->topk_per_nq_prefix_sum_[qi + 1];
            for (auto row = search_result->topk_per_nq_prefix_sum_[qi];
                 row < row_end;
                 ++row) {
                const auto dest = search_result->result_offsets_[row];
                group_by_values[dest] =
                    search_result->group_by_values_.value()[row];
            }
        }
    }
    AssembleGroupByValues(search_res_data, group_by_values, plan_);
}
// Compacts one segment's search result down to the rows selected by the
// reduce step (final_search_records_[seg_res_idx]), keeping primary keys,
// distances, segment offsets and group-by values aligned, and increments
// real_topks[j] once for every row kept for query j.
void
GroupReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
                                             int seg_res_idx,
                                             std::vector<int64_t>& real_topks) {
    AssertInfo(search_result->group_by_values_.has_value(),
               "no group by values for search result, group reducer should not "
               "be called, wrong code");
    AssertInfo(search_result->primary_keys_.size() ==
                   search_result->group_by_values_.value().size(),
               "Wrong size for group_by_values size before refresh:{}, "
               "not equal to "
               "primary_keys_.size:{}",
               search_result->group_by_values_.value().size(),
               search_result->primary_keys_.size());

    // First pass: count the rows that survive across all queries.
    uint32_t kept_total = 0;
    for (int nq_idx = 0; nq_idx < total_nq_; ++nq_idx) {
        kept_total += final_search_records_[seg_res_idx][nq_idx].size();
    }

    std::vector<milvus::PkType> kept_pks(kept_total);
    std::vector<float> kept_distances(kept_total);
    std::vector<int64_t> kept_seg_offsets(kept_total);
    std::vector<GroupByValueType> kept_group_vals(kept_total);

    // Second pass: copy the surviving rows in query order.
    uint32_t write_pos = 0;
    for (int nq_idx = 0; nq_idx < total_nq_; ++nq_idx) {
        for (auto src : final_search_records_[seg_res_idx][nq_idx]) {
            kept_pks[write_pos] = search_result->primary_keys_[src];
            kept_distances[write_pos] = search_result->distances_[src];
            kept_seg_offsets[write_pos] = search_result->seg_offsets_[src];
            kept_group_vals[write_pos] =
                search_result->group_by_values_.value()[src];
            ++write_pos;
            ++real_topks[nq_idx];
        }
    }

    // Install the compacted arrays in place of the originals.
    search_result->primary_keys_.swap(kept_pks);
    search_result->distances_.swap(kept_distances);
    search_result->seg_offsets_.swap(kept_seg_offsets);
    search_result->group_by_values_.value().swap(kept_group_vals);

    AssertInfo(search_result->primary_keys_.size() ==
                   search_result->group_by_values_.value().size(),
               "Wrong size for group_by_values size after refresh:{}, "
               "not equal to "
               "primary_keys_.size:{}",
               search_result->group_by_values_.value().size(),
               search_result->primary_keys_.size());
}
// Intentionally a no-op override for group-by reduce.
void
GroupReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
// Nothing to filter: the group-by search step already fills
// topk_per_nq_prefix_sum_ per nq and appends no padded/invalid rows,
// so search_result is left untouched.
}
// Merges the per-segment results of query `qi` into the final group-by result.
//
// qi     - index of the query being reduced
// topk   - maximum number of distinct groups to keep for this query
// offset - in/out cursor into the merged output; advanced once per accepted row
//
// Accepts at most group_size * topk rows, where group_size is read from
// search_results_[0]. Returns the number of candidate rows that were examined
// but rejected.
int64_t
GroupReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk,
int64_t& offset) {
// Heap of per-segment cursors ordered by SearchResultPairComparator; the top
// is the next candidate row to consider.
std::priority_queue<SearchResultPair*,
std::vector<SearchResultPair*>,
SearchResultPairComparator>
heap;
pk_set_.clear();
pairs_.clear();
// At most num_segments_ entries are appended below. Presumably pairs_ is a
// std::vector, in which case this reserve keeps the &pairs_.back() pointers
// pushed into the heap valid -- TODO confirm the container type.
pairs_.reserve(num_segments_);
for (int i = 0; i < num_segments_; i++) {
auto search_result = search_results_[i];
// [offset_beg, offset_end) is this segment's row range for query qi.
auto offset_beg = search_result->topk_per_nq_prefix_sum_[qi];
auto offset_end = search_result->topk_per_nq_prefix_sum_[qi + 1];
// segment has no rows for this query
if (offset_beg == offset_end) {
continue;
}
auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg];
AssertInfo(search_result->group_by_values_.has_value(),
"Wrong state, search_result has no group_by_vales for "
"group_by_reduce, must be sth wrong!");
AssertInfo(search_result->group_by_values_.value().size() ==
search_result->primary_keys_.size(),
"Wrong state, search_result's group_by_values's length is "
"not equal to pks' size!");
auto group_by_val = search_result->group_by_values_.value()[offset_beg];
// Seed the cursor with the segment's first row for this query.
pairs_.emplace_back(primary_key,
distance,
search_result,
i,
offset_beg,
offset_end,
std::move(group_by_val));
heap.push(&pairs_.back());
}
// no segment has any result for this nq
if (heap.size() == 0) {
return 0;
}
// group_size is taken from the first search result only.
int64_t group_size = search_results_[0]->group_size_.value();
int64_t group_by_total_size = group_size * topk;
int64_t filtered_count = 0;
auto start = offset;
// accepted-row count per group value for this query
std::unordered_map<GroupByValueType, int64_t> group_by_map;
// A candidate is rejected when (1) its pk was already accepted, (2) it would
// open a new group although topk groups already exist, or (3) its group
// already holds group_size rows. Check (2) runs before the operator[] in
// check (3), so operator[] only inserts a new key while fewer than topk
// groups exist.
auto should_filtered = [&](const PkType& pk,
const GroupByValueType& group_by_val) {
if (pk_set_.count(pk) != 0)
return true;
if (group_by_map.size() >= topk &&
group_by_map.count(group_by_val) == 0)
return true;
if (group_by_map[group_by_val] >= group_size)
return true;
return false;
};
// Stop once group_size * topk rows were accepted or every cursor is exhausted.
while (offset - start < group_by_total_size && !heap.empty()) {
// take the current top candidate off the heap
auto pilot = heap.top();
heap.pop();
auto index = pilot->segment_index_;
auto pk = pilot->primary_key_;
AssertInfo(pk != INVALID_PK,
"Wrong, search results should have been filtered and "
"invalid_pk should not be existed");
auto group_by_val = pilot->group_by_value_.value();
// accept or reject the candidate
if (!should_filtered(pk, group_by_val)) {
// Record the row's position in the merged output and remember which
// source row of segment `index` was chosen for query qi.
pilot->search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot->offset_);
pk_set_.insert(pk);
group_by_map[group_by_val] += 1;
} else {
filtered_count++;
}
// Move the cursor to its next row; presumably advance() sets primary_key_
// to INVALID_PK once the segment's range is exhausted -- TODO confirm.
pilot->advance();
if (pilot->primary_key_ != INVALID_PK) {
heap.push(pilot);
}
}
return filtered_count;
}
} // namespace milvus::segcore

View File

@ -0,0 +1,58 @@
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License
#pragma once
#include "Reduce.h"
#include "common/QueryResult.h"
#include "query/PlanImpl.h"
namespace milvus::segcore {
// ReduceHelper specialization for group-by search results. It overrides the
// per-query reduce, per-segment refresh, invalid-result filtering and extra
// data filling hooks of ReduceHelper.
class GroupReduceHelper : public ReduceHelper {
public:
// Forwards every argument unchanged to the ReduceHelper constructor.
explicit GroupReduceHelper(std::vector<SearchResult*>& search_results,
milvus::query::Plan* plan,
int64_t* slice_nqs,
int64_t* slice_topKs,
int64_t slice_num,
tracer::TraceContext* trace_ctx)
: ReduceHelper(search_results,
plan,
slice_nqs,
slice_topKs,
slice_num,
trace_ctx) {
}
protected:
// Override hook for filtering invalid rows out of one segment's result.
void
FilterInvalidSearchResult(SearchResult* search_result) override;
// Reduces the results of query `qi`; `topk` bounds the number of groups and
// `result_offset` is an in/out cursor into the merged output.
int64_t
ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk,
int64_t& result_offset) override;
// Rewrites one segment's result after reduce and updates `real_topks`.
void
RefreshSingleSearchResult(SearchResult* search_result,
int seg_res_idx,
std::vector<int64_t>& real_topks) override;
// Fills additional per-result data into `search_res_data` for queries in
// [nq_begin, nq_end).
void
FillOtherData(int result_count,
int64_t nq_begin,
int64_t nq_end,
std::unique_ptr<milvus::proto::schema::SearchResultData>&
search_res_data) override;
private:
// NOTE(review): not referenced by any member definition visible alongside
// this header, and <unordered_set> is not included here directly -- confirm
// whether this member is still needed.
std::unordered_set<milvus::GroupByValueType> group_by_val_set_{};
};
} // namespace milvus::segcore

View File

@ -11,15 +11,15 @@
#include "Reduce.h" #include "Reduce.h"
#include <log/Log.h> #include "log/Log.h"
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
#include "SegmentInterface.h" #include "segcore/SegmentInterface.h"
#include "Utils.h" #include "segcore/Utils.h"
#include "common/EasyAssert.h" #include "common/EasyAssert.h"
#include "pkVisitor.h" #include "segcore/pkVisitor.h"
#include "ReduceUtils.h" #include "segcore/ReduceUtils.h"
namespace milvus::segcore { namespace milvus::segcore {
@ -56,7 +56,7 @@ void
ReduceHelper::Reduce() { ReduceHelper::Reduce() {
FillPrimaryKey(); FillPrimaryKey();
ReduceResultData(); ReduceResultData();
RefreshSearchResult(); RefreshSearchResults();
FillEntryData(); FillEntryData();
} }
@ -90,13 +90,6 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
auto segment = static_cast<SegmentInterface*>(search_result->segment_); auto segment = static_cast<SegmentInterface*>(search_result->segment_);
auto& offsets = search_result->seg_offsets_; auto& offsets = search_result->seg_offsets_;
auto& distances = search_result->distances_; auto& distances = search_result->distances_;
if (search_result->group_by_values_.has_value()) {
AssertInfo(search_result->distances_.size() ==
search_result->group_by_values_.value().size(),
"wrong group_by_values size, size:{}, expected size:{} ",
search_result->group_by_values_.value().size(),
search_result->distances_.size());
}
for (auto i = 0; i < nq; ++i) { for (auto i = 0; i < nq; ++i) {
for (auto j = 0; j < topK; ++j) { for (auto j = 0; j < topK; ++j) {
@ -112,18 +105,12 @@ ReduceHelper::FilterInvalidSearchResult(SearchResult* search_result) {
real_topks[i]++; real_topks[i]++;
offsets[valid_index] = offsets[index]; offsets[valid_index] = offsets[index];
distances[valid_index] = distances[index]; distances[valid_index] = distances[index];
if (search_result->group_by_values_.has_value())
search_result->group_by_values_.value()[valid_index] =
search_result->group_by_values_.value()[index];
valid_index++; valid_index++;
} }
} }
} }
offsets.resize(valid_index); offsets.resize(valid_index);
distances.resize(valid_index); distances.resize(valid_index);
if (search_result->group_by_values_.has_value())
search_result->group_by_values_.value().resize(valid_index);
search_result->topk_per_nq_prefix_sum_.resize(nq + 1); search_result->topk_per_nq_prefix_sum_.resize(nq + 1);
std::partial_sum(real_topks.begin(), std::partial_sum(real_topks.begin(),
real_topks.end(), real_topks.end(),
@ -154,59 +141,14 @@ ReduceHelper::FillPrimaryKey() {
} }
void void
ReduceHelper::RefreshSearchResult() { ReduceHelper::RefreshSearchResults() {
tracer::AutoSpan span( tracer::AutoSpan span(
"ReduceHelper::RefreshSearchResult", trace_ctx_, false); "ReduceHelper::RefreshSearchResults", trace_ctx_, false);
for (int i = 0; i < num_segments_; i++) { for (int i = 0; i < num_segments_; i++) {
std::vector<int64_t> real_topks(total_nq_, 0); std::vector<int64_t> real_topks(total_nq_, 0);
auto search_result = search_results_[i]; auto search_result = search_results_[i];
if (search_result->group_by_values_.has_value()) {
AssertInfo(search_result->primary_keys_.size() ==
search_result->group_by_values_.value().size(),
"Wrong size for group_by_values size before refresh:{}, "
"not equal to "
"primary_keys_.size:{}",
search_result->group_by_values_.value().size(),
search_result->primary_keys_.size());
}
if (search_result->result_offsets_.size() != 0) { if (search_result->result_offsets_.size() != 0) {
uint32_t size = 0; RefreshSingleSearchResult(search_result, i, real_topks);
for (int j = 0; j < total_nq_; j++) {
size += final_search_records_[i][j].size();
}
std::vector<milvus::PkType> primary_keys(size);
std::vector<float> distances(size);
std::vector<int64_t> seg_offsets(size);
std::vector<GroupByValueType> group_by_values(size);
uint32_t index = 0;
for (int j = 0; j < total_nq_; j++) {
for (auto offset : final_search_records_[i][j]) {
primary_keys[index] = search_result->primary_keys_[offset];
distances[index] = search_result->distances_[offset];
seg_offsets[index] = search_result->seg_offsets_[offset];
if (search_result->group_by_values_.has_value())
group_by_values[index] =
search_result->group_by_values_.value()[offset];
index++;
real_topks[j]++;
}
}
search_result->primary_keys_.swap(primary_keys);
search_result->distances_.swap(distances);
search_result->seg_offsets_.swap(seg_offsets);
if (search_result->group_by_values_.has_value()) {
search_result->group_by_values_.value().swap(group_by_values);
}
}
if (search_result->group_by_values_.has_value()) {
AssertInfo(search_result->primary_keys_.size() ==
search_result->group_by_values_.value().size(),
"Wrong size for group_by_values size after refresh:{}, "
"not equal to "
"primary_keys_.size:{}",
search_result->group_by_values_.value().size(),
search_result->primary_keys_.size());
} }
std::partial_sum(real_topks.begin(), std::partial_sum(real_topks.begin(),
real_topks.end(), real_topks.end(),
@ -214,6 +156,33 @@ ReduceHelper::RefreshSearchResult() {
} }
} }
// Rewrites one segment's search result so that it holds only the entries
// that were selected during reduce.
//
// For each query j, final_search_records_[seg_res_idx][j] lists the offsets
// (into this result's current primary_keys_ / distances_ / seg_offsets_)
// that were kept. Those entries are gathered in query order into fresh
// arrays, which then replace the originals.
//
// search_result: result whose arrays are replaced in place.
// seg_res_idx:   first-dimension index into final_search_records_ for this
//                result.
// real_topks:    per-query counters; real_topks[j] is incremented once for
//                every entry kept for query j (it is not reset here).
void
ReduceHelper::RefreshSingleSearchResult(SearchResult* search_result,
                                        int seg_res_idx,
                                        std::vector<int64_t>& real_topks) {
    const auto& kept_per_query = final_search_records_[seg_res_idx];

    // Accumulate in size_t: the per-query sizes are size_t, and a 32-bit
    // total would silently narrow when summing them.
    size_t size = 0;
    for (int64_t j = 0; j < total_nq_; j++) {
        size += kept_per_query[j].size();
    }

    std::vector<milvus::PkType> primary_keys(size);
    std::vector<float> distances(size);
    std::vector<int64_t> seg_offsets(size);

    // Single write cursor shared by all queries, so the output stays
    // grouped by query in ascending query order.
    size_t index = 0;
    for (int64_t j = 0; j < total_nq_; j++) {
        for (auto offset : kept_per_query[j]) {
            primary_keys[index] = search_result->primary_keys_[offset];
            distances[index] = search_result->distances_[offset];
            seg_offsets[index] = search_result->seg_offsets_[offset];
            index++;
            real_topks[j]++;
        }
    }

    // Swap rather than assign: the compacted arrays take over without an
    // extra copy, and the old storage is released with the locals.
    search_result->primary_keys_.swap(primary_keys);
    search_result->distances_.swap(distances);
    search_result->seg_offsets_.swap(seg_offsets);
}
void void
ReduceHelper::FillEntryData() { ReduceHelper::FillEntryData() {
tracer::AutoSpan span("ReduceHelper::FillEntryData", trace_ctx_, false); tracer::AutoSpan span("ReduceHelper::FillEntryData", trace_ctx_, false);
@ -228,12 +197,12 @@ int64_t
ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi, ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk, int64_t topk,
int64_t& offset) { int64_t& offset) {
while (!heap_.empty()) { std::priority_queue<SearchResultPair*,
heap_.pop(); std::vector<SearchResultPair*>,
} SearchResultPairComparator>
heap;
pk_set_.clear(); pk_set_.clear();
pairs_.clear(); pairs_.clear();
group_by_val_set_.clear();
pairs_.reserve(num_segments_); pairs_.reserve(num_segments_);
for (int i = 0; i < num_segments_; i++) { for (int i = 0; i < num_segments_; i++) {
@ -245,41 +214,21 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
} }
auto primary_key = search_result->primary_keys_[offset_beg]; auto primary_key = search_result->primary_keys_[offset_beg];
auto distance = search_result->distances_[offset_beg]; auto distance = search_result->distances_[offset_beg];
if (search_result->group_by_values_.has_value()) {
AssertInfo(
search_result->group_by_values_.value().size() > offset_beg,
"Wrong size for group_by_values size to "
"ReduceSearchResultForOneNQ:{}, not enough for"
"required offset_beg:{}",
search_result->group_by_values_.value().size(),
offset_beg);
}
pairs_.emplace_back( pairs_.emplace_back(
primary_key, primary_key, distance, search_result, i, offset_beg, offset_end);
distance, heap.push(&pairs_.back());
search_result,
i,
offset_beg,
offset_end,
search_result->group_by_values_.has_value() &&
search_result->group_by_values_.value().size() > offset_beg
? std::make_optional(
search_result->group_by_values_.value().at(offset_beg))
: std::nullopt);
heap_.push(&pairs_.back());
} }
// nq has no results for all segments // nq has no results for all segments
if (heap_.size() == 0) { if (heap.size() == 0) {
return 0; return 0;
} }
int64_t dup_cnt = 0; int64_t dup_cnt = 0;
auto start = offset; auto start = offset;
while (offset - start < topk && !heap_.empty()) { while (offset - start < topk && !heap.empty()) {
auto pilot = heap_.top(); auto pilot = heap.top();
heap_.pop(); heap.pop();
auto index = pilot->segment_index_; auto index = pilot->segment_index_;
auto pk = pilot->primary_key_; auto pk = pilot->primary_key_;
@ -289,27 +238,16 @@ ReduceHelper::ReduceSearchResultForOneNQ(int64_t qi,
} }
// remove duplicates // remove duplicates
if (pk_set_.count(pk) == 0) { if (pk_set_.count(pk) == 0) {
bool skip_for_group_by = false; pilot->search_result_->result_offsets_.push_back(offset++);
if (pilot->group_by_value_.has_value()) { final_search_records_[index][qi].push_back(pilot->offset_);
if (group_by_val_set_.count(pilot->group_by_value_.value()) > pk_set_.insert(pk);
0) {
skip_for_group_by = true;
}
}
if (!skip_for_group_by) {
pilot->search_result_->result_offsets_.push_back(offset++);
final_search_records_[index][qi].push_back(pilot->offset_);
pk_set_.insert(pk);
if (pilot->group_by_value_.has_value())
group_by_val_set_.insert(pilot->group_by_value_.value());
}
} else { } else {
// skip entity with same primary key // skip entity with same primary key
dup_cnt++; dup_cnt++;
} }
pilot->advance(); pilot->advance();
if (pilot->primary_key_ != INVALID_PK) { if (pilot->primary_key_ != INVALID_PK) {
heap_.push(pilot); heap.push(pilot);
} }
} }
return dup_cnt; return dup_cnt;
@ -331,7 +269,7 @@ ReduceHelper::ReduceResultData() {
"incorrect search result primary key size"); "incorrect search result primary key size");
} }
int64_t skip_dup_cnt = 0; int64_t filtered_count = 0;
for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) { for (int64_t slice_index = 0; slice_index < num_slices_; slice_index++) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index]; auto nq_begin = slice_nqs_prefix_sum_[slice_index];
auto nq_end = slice_nqs_prefix_sum_[slice_index + 1]; auto nq_end = slice_nqs_prefix_sum_[slice_index + 1];
@ -339,15 +277,24 @@ ReduceHelper::ReduceResultData() {
// reduce search results // reduce search results
int64_t offset = 0; int64_t offset = 0;
for (int64_t qi = nq_begin; qi < nq_end; qi++) { for (int64_t qi = nq_begin; qi < nq_end; qi++) {
skip_dup_cnt += ReduceSearchResultForOneNQ( filtered_count += ReduceSearchResultForOneNQ(
qi, slice_topKs_[slice_index], offset); qi, slice_topKs_[slice_index], offset);
} }
} }
if (skip_dup_cnt > 0) { if (filtered_count > 0) {
LOG_DEBUG("skip duplicated search result, count = {}", skip_dup_cnt); LOG_DEBUG("skip duplicated search result, count = {}", filtered_count);
} }
} }
// Hook invoked from GetSearchResultDataSlice after pks and scores have been
// written into the slice's SearchResultData.
//
// The plain (non group-by) reduce has no additional data to attach, so this
// implementation is intentionally empty. The method is declared virtual in
// the class so that another reduce helper can supply its own data here.
//
// result_count:    number of results in the slice.
// nq_begin/nq_end: query range covered by the slice.
// search_res_data: slice result being assembled; left untouched here.
void
ReduceHelper::FillOtherData(
    int result_count,
    int64_t nq_begin,
    int64_t nq_end,
    std::unique_ptr<milvus::proto::schema::SearchResultData>&
        search_res_data) {
    // Nothing to do: simple batch reduce carries no other data.
}
std::vector<char> std::vector<char>
ReduceHelper::GetSearchResultDataSlice(int slice_index) { ReduceHelper::GetSearchResultDataSlice(int slice_index) {
auto nq_begin = slice_nqs_prefix_sum_[slice_index]; auto nq_begin = slice_nqs_prefix_sum_[slice_index];
@ -370,7 +317,6 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
search_result_data->set_top_k(slice_topKs_[slice_index]); search_result_data->set_top_k(slice_topKs_[slice_index]);
search_result_data->set_num_queries(nq_end - nq_begin); search_result_data->set_num_queries(nq_end - nq_begin);
search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0); search_result_data->mutable_topks()->Resize(nq_end - nq_begin, 0);
search_result_data->set_all_search_count(all_search_count); search_result_data->set_all_search_count(all_search_count);
// `result_pairs` contains the SearchResult and result_offset info, used for filling output fields // `result_pairs` contains the SearchResult and result_offset info, used for filling output fields
@ -407,12 +353,6 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
// reserve space for distances // reserve space for distances
search_result_data->mutable_scores()->Resize(result_count, 0); search_result_data->mutable_scores()->Resize(result_count, 0);
//reserve space for group_by_values
std::vector<GroupByValueType> group_by_values;
if (plan_->plan_node_->search_info_.group_by_field_id_.has_value()) {
group_by_values.resize(result_count);
}
// fill pks and distances // fill pks and distances
for (auto qi = nq_begin; qi < nq_end; qi++) { for (auto qi = nq_begin; qi < nq_end; qi++) {
int64_t topk_count = 0; int64_t topk_count = 0;
@ -461,11 +401,6 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
search_result_data->mutable_scores()->Set( search_result_data->mutable_scores()->Set(
loc, search_result->distances_[ki]); loc, search_result->distances_[ki]);
// set group by values
if (search_result->group_by_values_.has_value() &&
ki < search_result->group_by_values_.value().size())
group_by_values[loc] =
search_result->group_by_values_.value()[ki];
// set result offset to fill output fields data // set result offset to fill output fields data
result_pairs[loc] = {&search_result->output_fields_data_, ki}; result_pairs[loc] = {&search_result->output_fields_data_, ki};
} }
@ -474,12 +409,13 @@ ReduceHelper::GetSearchResultDataSlice(int slice_index) {
// update result topKs // update result topKs
search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count); search_result_data->mutable_topks()->Set(qi - nq_begin, topk_count);
} }
AssembleGroupByValues(search_result_data, group_by_values, plan_);
AssertInfo(search_result_data->scores_size() == result_count, AssertInfo(search_result_data->scores_size() == result_count,
"wrong scores size, size = " + "wrong scores size, size = " +
std::to_string(search_result_data->scores_size()) + std::to_string(search_result_data->scores_size()) +
", expected size = " + std::to_string(result_count)); ", expected size = " + std::to_string(result_count));
// fill other wanted data
FillOtherData(result_count, nq_begin, nq_end, search_result_data);
// set output fields // set output fields
for (auto field_id : plan_->target_entries_) { for (auto field_id : plan_->target_entries_) {

View File

@ -21,9 +21,9 @@
#include "common/type_c.h" #include "common/type_c.h"
#include "common/QueryResult.h" #include "common/QueryResult.h"
#include "query/PlanImpl.h" #include "query/PlanImpl.h"
#include "ReduceStructure.h" #include "segcore/ReduceStructure.h"
#include "common/Tracer.h" #include "common/Tracer.h"
#include "segment_c.h" #include "segcore/segment_c.h"
namespace milvus::segcore { namespace milvus::segcore {
@ -60,11 +60,16 @@ class ReduceHelper {
} }
protected: protected:
void virtual void
FilterInvalidSearchResult(SearchResult* search_result); FilterInvalidSearchResult(SearchResult* search_result);
void void
RefreshSearchResult(); RefreshSearchResults();
virtual void
RefreshSingleSearchResult(SearchResult* search_result,
int seg_res_idx,
std::vector<int64_t>& real_topks);
void void
FillPrimaryKey(); FillPrimaryKey();
@ -72,6 +77,18 @@ class ReduceHelper {
void void
ReduceResultData(); ReduceResultData();
virtual int64_t
ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk,
int64_t& result_offset);
virtual void
FillOtherData(int result_count,
int64_t nq_begin,
int64_t nq_end,
std::unique_ptr<milvus::proto::schema::SearchResultData>&
search_res_data);
private: private:
void void
Initialize(); Initialize();
@ -79,11 +96,6 @@ class ReduceHelper {
void void
FillEntryData(); FillEntryData();
int64_t
ReduceSearchResultForOneNQ(int64_t qi,
int64_t topk,
int64_t& result_offset);
std::vector<char> std::vector<char>
GetSearchResultDataSlice(int slice_index_); GetSearchResultDataSlice(int slice_index_);
@ -94,25 +106,16 @@ class ReduceHelper {
std::vector<int64_t> slice_nqs_prefix_sum_; std::vector<int64_t> slice_nqs_prefix_sum_;
int64_t num_segments_; int64_t num_segments_;
std::vector<int64_t> slice_topKs_; std::vector<int64_t> slice_topKs_;
std::priority_queue<SearchResultPair*,
std::vector<SearchResultPair*>,
SearchResultPairComparator>
heap_;
// Used for merge results, // Used for merge results,
// define these here to avoid allocating them for each query // define these here to avoid allocating them for each query
std::vector<SearchResultPair> pairs_; std::vector<SearchResultPair> pairs_;
std::unordered_set<milvus::PkType> pk_set_; std::unordered_set<milvus::PkType> pk_set_;
std::unordered_set<milvus::GroupByValueType> group_by_val_set_;
// dim0: num_segments_; dim1: total_nq_; dim2: offset // dim0: num_segments_; dim1: total_nq_; dim2: offset
std::vector<std::vector<std::vector<int64_t>>> final_search_records_; std::vector<std::vector<std::vector<int64_t>>> final_search_records_;
private:
std::vector<int64_t> slice_nqs_; std::vector<int64_t> slice_nqs_;
int64_t total_nq_; int64_t total_nq_;
// output // output
std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_; std::unique_ptr<SearchResultDataBlobs> search_result_data_blobs_;
tracer::TraceContext* trace_ctx_; tracer::TraceContext* trace_ctx_;
}; };

View File

@ -10,9 +10,9 @@
// or implied. See the License for the specific language governing permissions and limitations under the License // or implied. See the License for the specific language governing permissions and limitations under the License
#include "StreamReduce.h" #include "StreamReduce.h"
#include "SegmentInterface.h" #include "segcore/SegmentInterface.h"
#include "segcore/Utils.h" #include "segcore/Utils.h"
#include "Reduce.h" #include "segcore/reduce/Reduce.h"
#include "segcore/pkVisitor.h" #include "segcore/pkVisitor.h"
#include "segcore/ReduceUtils.h" #include "segcore/ReduceUtils.h"

View File

@ -10,12 +10,13 @@
// or implied. See the License for the specific language governing permissions and limitations under the License // or implied. See the License for the specific language governing permissions and limitations under the License
#include <vector> #include <vector>
#include "Reduce.h" #include "segcore/reduce/Reduce.h"
#include "segcore/reduce/GroupReduce.h"
#include "common/QueryResult.h" #include "common/QueryResult.h"
#include "common/EasyAssert.h" #include "common/EasyAssert.h"
#include "query/Plan.h" #include "query/Plan.h"
#include "segcore/reduce_c.h" #include "segcore/reduce_c.h"
#include "segcore/StreamReduce.h" #include "segcore/reduce/StreamReduce.h"
#include "segcore/Utils.h" #include "segcore/Utils.h"
using SearchResult = milvus::SearchResult; using SearchResult = milvus::SearchResult;
@ -95,17 +96,30 @@ ReduceSearchResultsAndFillData(CTraceContext c_trace,
search_results[i] = static_cast<SearchResult*>(c_search_results[i]); search_results[i] = static_cast<SearchResult*>(c_search_results[i]);
} }
auto reduce_helper = milvus::segcore::ReduceHelper(search_results, std::shared_ptr<milvus::segcore::ReduceHelper> reduce_helper;
plan, if (plan->plan_node_->search_info_.group_by_field_id_.has_value()) {
slice_nqs, reduce_helper =
slice_topKs, std::make_shared<milvus::segcore::GroupReduceHelper>(
num_slices, search_results,
&trace_ctx); plan,
reduce_helper.Reduce(); slice_nqs,
reduce_helper.Marshal(); slice_topKs,
num_slices,
&trace_ctx);
} else {
reduce_helper =
std::make_shared<milvus::segcore::ReduceHelper>(search_results,
plan,
slice_nqs,
slice_topKs,
num_slices,
&trace_ctx);
}
reduce_helper->Reduce();
reduce_helper->Marshal();
// set final result ptr // set final result ptr
*cSearchResultDataBlobs = reduce_helper.GetSearchResultDataBlobs(); *cSearchResultDataBlobs = reduce_helper->GetSearchResultDataBlobs();
return milvus::SuccessCStatus(); return milvus::SuccessCStatus();
} catch (std::exception& e) { } catch (std::exception& e) {
return milvus::FailureCStatus(&e); return milvus::FailureCStatus(&e);

View File

@ -31,7 +31,7 @@
#include "pb/plan.pb.h" #include "pb/plan.pb.h"
#include "query/ExprImpl.h" #include "query/ExprImpl.h"
#include "segcore/Collection.h" #include "segcore/Collection.h"
#include "segcore/Reduce.h" #include "segcore/reduce/Reduce.h"
#include "segcore/reduce_c.h" #include "segcore/reduce_c.h"
#include "segcore/segment_c.h" #include "segcore/segment_c.h"
#include "futures/Future.h" #include "futures/Future.h"

View File

@ -16,7 +16,7 @@
#include "index/IndexFactory.h" #include "index/IndexFactory.h"
#include "knowhere/comp/index_param.h" #include "knowhere/comp/index_param.h"
#include "query/ExprImpl.h" #include "query/ExprImpl.h"
#include "segcore/Reduce.h" #include "segcore/reduce/Reduce.h"
#include "segcore/reduce_c.h" #include "segcore/reduce_c.h"
#include "test_utils/DataGen.h" #include "test_utils/DataGen.h"
#include "test_utils/PbHelper.h" #include "test_utils/PbHelper.h"

View File

@ -64,20 +64,13 @@ CheckGroupBySearchResult(const SearchResult& search_result,
int topK, int topK,
int nq, int nq,
bool strict) { bool strict) {
int total = topK * nq; int size = search_result.group_by_values_.value().size();
ASSERT_EQ(search_result.group_by_values_.value().size(), total); ASSERT_EQ(search_result.seg_offsets_.size(), size);
ASSERT_EQ(search_result.seg_offsets_.size(), total); ASSERT_EQ(search_result.distances_.size(), size);
ASSERT_EQ(search_result.distances_.size(), total);
ASSERT_TRUE(search_result.seg_offsets_[0] != INVALID_SEG_OFFSET); ASSERT_TRUE(search_result.seg_offsets_[0] != INVALID_SEG_OFFSET);
int res_bound = GetSearchResultBound(search_result); ASSERT_TRUE(search_result.seg_offsets_[size - 1] != INVALID_SEG_OFFSET);
ASSERT_TRUE(res_bound > 0); ASSERT_EQ(search_result.topk_per_nq_prefix_sum_.size(), nq + 1);
if (strict) { ASSERT_EQ(size, search_result.topk_per_nq_prefix_sum_[nq]);
ASSERT_TRUE(res_bound == total - 1);
} else {
ASSERT_TRUE(res_bound == total - 1 ||
search_result.seg_offsets_[res_bound + 1] ==
INVALID_SEG_OFFSET);
}
} }
TEST(GroupBY, SealedIndex) { TEST(GroupBY, SealedIndex) {
@ -98,10 +91,10 @@ TEST(GroupBY, SealedIndex) {
auto bool_fid = schema->AddDebugField("bool", DataType::BOOL); auto bool_fid = schema->AddDebugField("bool", DataType::BOOL);
schema->set_primary_field_id(str_fid); schema->set_primary_field_id(str_fid);
auto segment = CreateSealedSegment(schema); auto segment = CreateSealedSegment(schema);
size_t N = 100; size_t N = 50;
//2. load raw data //2. load raw data
auto raw_data = DataGen(schema, N); auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false);
auto fields = schema->get_fields(); auto fields = schema->get_fields();
for (auto field_data : raw_data.raw_->fields_data()) { for (auto field_data : raw_data.raw_->fields_data()) {
int64_t field_id = field_data.field_id(); int64_t field_id = field_data.field_id();
@ -125,24 +118,27 @@ TEST(GroupBY, SealedIndex) {
load_index_info.index = std::move(indexing); load_index_info.index = std::move(indexing);
load_index_info.index_params[METRICS_TYPE] = knowhere::metric::L2; load_index_info.index_params[METRICS_TYPE] = knowhere::metric::L2;
segment->LoadIndex(load_index_info); segment->LoadIndex(load_index_info);
int topK = 100; int topK = 15;
int group_size = 3;
//4. search group by int8 //4. search group by int8
{ {
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
query_info: < query_info: <
topk: 100 topk: 15
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 101 group_by_field_id: 101
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); proto::plan::PlanNode plan_node;
auto plan = auto ok =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node);
auto num_queries = 1; auto num_queries = 1;
auto seed = 1024; auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
@ -153,29 +149,31 @@ TEST(GroupBY, SealedIndex) {
CheckGroupBySearchResult(*search_result, topK, num_queries, false); CheckGroupBySearchResult(*search_result, topK, num_queries, false);
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
ASSERT_EQ(20, group_by_values.size());
//as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively
//so there will be 20 items returned
int size = group_by_values.size(); int size = group_by_values.size();
std::unordered_set<int8_t> i8_set; std::unordered_map<int8_t, int> i8_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int8_t>(group_by_values[i])) { if (std::holds_alternative<int8_t>(group_by_values[i])) {
int8_t g_val = std::get<int8_t>(group_by_values[i]); int8_t g_val = std::get<int8_t>(group_by_values[i]);
ASSERT_FALSE(i8_set.count(g_val) > i8_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(i8_map[g_val] <= group_size);
i8_set.insert(g_val); //for every group, the number of hits should not exceed group_size
auto distance = search_result->distances_.at(i); auto distance = search_result->distances_.at(i);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
} }
} }
ASSERT_TRUE(i8_map.size() <= topK);
ASSERT_TRUE(i8_map.size() == 7);
} }
//4. search group by int16 //5. search group by int16
{ {
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
@ -184,14 +182,16 @@ TEST(GroupBY, SealedIndex) {
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 102 group_by_field_id: 102
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); proto::plan::PlanNode plan_node;
auto plan = auto ok =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node);
auto num_queries = 1; auto num_queries = 1;
auto seed = 1024; auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
@ -203,28 +203,27 @@ TEST(GroupBY, SealedIndex) {
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
int size = group_by_values.size(); int size = group_by_values.size();
std::unordered_set<int16_t> i16_set; ASSERT_EQ(20, size);
//as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively
//so there will be 20 items returned
std::unordered_map<int16_t, int> i16_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int16_t>(group_by_values[i])) { if (std::holds_alternative<int16_t>(group_by_values[i])) {
int16_t g_val = std::get<int16_t>(group_by_values[i]); int16_t g_val = std::get<int16_t>(group_by_values[i]);
ASSERT_FALSE(i16_set.count(g_val) > i16_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(i16_map[g_val] <= group_size);
i16_set.insert(g_val);
auto distance = search_result->distances_.at(i); auto distance = search_result->distances_.at(i);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
} }
} }
ASSERT_TRUE(i16_map.size() == 7);
} }
//6. search group by int32
//4. search group by int32
{ {
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
@ -233,14 +232,16 @@ TEST(GroupBY, SealedIndex) {
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 103 group_by_field_id: 103
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); proto::plan::PlanNode plan_node;
auto plan = auto ok =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node);
auto num_queries = 1; auto num_queries = 1;
auto seed = 1024; auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
@ -249,31 +250,31 @@ TEST(GroupBY, SealedIndex) {
auto search_result = auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63); segment->Search(plan.get(), ph_group.get(), 1L << 63);
CheckGroupBySearchResult(*search_result, topK, num_queries, false); CheckGroupBySearchResult(*search_result, topK, num_queries, false);
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
int size = group_by_values.size(); int size = group_by_values.size();
ASSERT_EQ(20, size);
//as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively
//so there will be 20 items returned
std::unordered_set<int32_t> i32_set; std::unordered_map<int32_t, int> i32_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int32_t>(group_by_values[i])) { if (std::holds_alternative<int32_t>(group_by_values[i])) {
int16_t g_val = std::get<int32_t>(group_by_values[i]); int16_t g_val = std::get<int32_t>(group_by_values[i]);
ASSERT_FALSE(i32_set.count(g_val) > i32_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(i32_map[g_val] <= group_size);
i32_set.insert(g_val);
auto distance = search_result->distances_.at(i); auto distance = search_result->distances_.at(i);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
} }
} }
ASSERT_TRUE(i32_map.size() == 7);
} }
//4. search group by int64 //7. search group by int64
{ {
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
@ -282,14 +283,16 @@ TEST(GroupBY, SealedIndex) {
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 104 group_by_field_id: 104
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); proto::plan::PlanNode plan_node;
auto plan = auto ok =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node);
auto num_queries = 1; auto num_queries = 1;
auto seed = 1024; auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
@ -299,30 +302,29 @@ TEST(GroupBY, SealedIndex) {
segment->Search(plan.get(), ph_group.get(), 1L << 63); segment->Search(plan.get(), ph_group.get(), 1L << 63);
CheckGroupBySearchResult(*search_result, topK, num_queries, false); CheckGroupBySearchResult(*search_result, topK, num_queries, false);
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
int size = group_by_values.size(); int size = group_by_values.size();
std::unordered_set<int64_t> i64_set; ASSERT_EQ(20, size);
//as the total data is 0,0,....6,6, so there will be 7 buckets with [3,3,3,3,3,3,2] items respectively
//so there will be 20 items returned
std::unordered_map<int64_t, int> i64_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int64_t>(group_by_values[i])) { if (std::holds_alternative<int64_t>(group_by_values[i])) {
int16_t g_val = std::get<int64_t>(group_by_values[i]); int16_t g_val = std::get<int64_t>(group_by_values[i]);
ASSERT_FALSE(i64_set.count(g_val) > i64_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(i64_map[g_val] <= group_size);
i64_set.insert(g_val);
auto distance = search_result->distances_.at(i); auto distance = search_result->distances_.at(i);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
} }
} }
ASSERT_TRUE(i64_map.size() == 7);
} }
//4. search group by string //8. search group by string
{ {
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
@ -331,14 +333,16 @@ TEST(GroupBY, SealedIndex) {
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 105 group_by_field_id: 105
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); proto::plan::PlanNode plan_node;
auto plan = auto ok =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node);
auto num_queries = 1; auto num_queries = 1;
auto seed = 1024; auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
@ -348,30 +352,28 @@ TEST(GroupBY, SealedIndex) {
segment->Search(plan.get(), ph_group.get(), 1L << 63); segment->Search(plan.get(), ph_group.get(), 1L << 63);
CheckGroupBySearchResult(*search_result, topK, num_queries, false); CheckGroupBySearchResult(*search_result, topK, num_queries, false);
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
ASSERT_EQ(20, group_by_values.size());
int size = group_by_values.size(); int size = group_by_values.size();
std::unordered_set<std::string> strs_set;
std::unordered_map<std::string, int> strs_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<std::string>(group_by_values[i])) { if (std::holds_alternative<std::string>(group_by_values[i])) {
std::string g_val = std::string g_val =
std::move(std::get<std::string>(group_by_values[i])); std::move(std::get<std::string>(group_by_values[i]));
ASSERT_FALSE(strs_set.count(g_val) > strs_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(strs_map[g_val] <= group_size);
strs_set.insert(g_val);
auto distance = search_result->distances_.at(i); auto distance = search_result->distances_.at(i);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
} }
} }
ASSERT_TRUE(strs_map.size() == 7);
} }
//4. search group by bool //9. search group by bool
{ {
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
@ -380,14 +382,16 @@ TEST(GroupBY, SealedIndex) {
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 106 group_by_field_id: 106
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>)"; >)";
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); proto::plan::PlanNode plan_node;
auto plan = auto ok =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); google::protobuf::TextFormat::ParseFromString(raw_plan, &plan_node);
auto plan = CreateSearchPlanFromPlanNode(*schema, plan_node);
auto num_queries = 1; auto num_queries = 1;
auto seed = 1024; auto seed = 1024;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
@ -396,30 +400,28 @@ TEST(GroupBY, SealedIndex) {
auto search_result = auto search_result =
segment->Search(plan.get(), ph_group.get(), 1L << 63); segment->Search(plan.get(), ph_group.get(), 1L << 63);
CheckGroupBySearchResult(*search_result, topK, num_queries, false); CheckGroupBySearchResult(*search_result, topK, num_queries, false);
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
int size = group_by_values.size(); int size = group_by_values.size();
std::unordered_set<bool> bools_set; ASSERT_EQ(size, 6);
int boolValCount = 0; //as there are only two possible values: true, false
//for each group, there are at most 3 items, so the final size of group_by_vals is 3 * 2 = 6
std::unordered_map<bool, int> bools_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<bool>(group_by_values[i])) { if (std::holds_alternative<bool>(group_by_values[i])) {
bool g_val = std::get<bool>(group_by_values[i]); bool g_val = std::get<bool>(group_by_values[i]);
ASSERT_FALSE(bools_set.count(g_val) > bools_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(bools_map[g_val] <= group_size);
bools_set.insert(g_val);
boolValCount += 1;
auto distance = search_result->distances_.at(i); auto distance = search_result->distances_.at(i);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
} }
ASSERT_TRUE(boolValCount <= 2); //bool values cannot exceed two
} }
ASSERT_TRUE(bools_map.size() == 2); //bool values cannot exceed two
} }
} }
@ -444,7 +446,7 @@ TEST(GroupBY, SealedData) {
size_t N = 100; size_t N = 100;
//2. load raw data //2. load raw data
auto raw_data = DataGen(schema, N); auto raw_data = DataGen(schema, N, 42, 0, 8, 10, false, false);
auto fields = schema->get_fields(); auto fields = schema->get_fields();
for (auto field_data : raw_data.raw_->fields_data()) { for (auto field_data : raw_data.raw_->fields_data()) {
int64_t field_id = field_data.field_id(); int64_t field_id = field_data.field_id();
@ -459,16 +461,18 @@ TEST(GroupBY, SealedData) {
} }
prepareSegmentSystemFieldData(segment, N, raw_data); prepareSegmentSystemFieldData(segment, N, raw_data);
int topK = 100; int topK = 10;
int group_size = 5;
//3. search group by int8 //3. search group by int8
{ {
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
query_info: < query_info: <
topk: 100 topk: 10
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 101 group_by_field_id: 101,
group_size: 5,
> >
placeholder_tag: "$0" placeholder_tag: "$0"
@ -487,25 +491,27 @@ TEST(GroupBY, SealedData) {
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
int size = group_by_values.size(); int size = group_by_values.size();
std::unordered_set<int8_t> i8_set; //as the repeated is 8, so there will be 13 groups and enough 10 * 5 = 50 results
ASSERT_EQ(50, size);
std::unordered_map<int8_t, int> i8_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
if (std::holds_alternative<int8_t>(group_by_values[i])) { if (std::holds_alternative<int8_t>(group_by_values[i])) {
int8_t g_val = std::get<int8_t>(group_by_values[i]); int8_t g_val = std::get<int8_t>(group_by_values[i]);
ASSERT_FALSE(i8_set.count(g_val) > i8_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(i8_map[g_val] <= group_size);
i8_set.insert(g_val);
auto distance = search_result->distances_.at(i); auto distance = search_result->distances_.at(i);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[i], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[i], 0.0);
} }
} }
ASSERT_TRUE(i8_map.size() == topK);
for (const auto& it : i8_map) {
ASSERT_TRUE(it.second == group_size);
}
} }
} }
@ -534,8 +540,10 @@ TEST(GroupBY, Reduce) {
uint64_t ts_offset = 0; uint64_t ts_offset = 0;
int repeat_count_1 = 2; int repeat_count_1 = 2;
int repeat_count_2 = 5; int repeat_count_2 = 5;
auto raw_data1 = DataGen(schema, N, seed, ts_offset, repeat_count_1); auto raw_data1 =
auto raw_data2 = DataGen(schema, N, seed, ts_offset, repeat_count_2); DataGen(schema, N, seed, ts_offset, repeat_count_1, false, false);
auto raw_data2 =
DataGen(schema, N, seed, ts_offset, repeat_count_2, false, false);
auto fields = schema->get_fields(); auto fields = schema->get_fields();
//load segment1 raw data //load segment1 raw data
@ -582,13 +590,17 @@ TEST(GroupBY, Reduce) {
segment2->LoadIndex(load_index_info_2); segment2->LoadIndex(load_index_info_2);
//4. search group by respectively //4. search group by respectively
auto num_queries = 10;
auto topK = 10;
int group_size = 3;
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 100 field_id: 100
query_info: < query_info: <
topk: 100 topk: 10
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 101 group_by_field_id: 101
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
@ -596,8 +608,6 @@ TEST(GroupBY, Reduce) {
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 10;
auto topK = 100;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
@ -629,7 +639,7 @@ TEST(GroupBY, Reduce) {
slice_nqs.data(), slice_nqs.data(),
slice_topKs.data(), slice_topKs.data(),
slice_nqs.size()); slice_nqs.size());
CheckSearchResultDuplicate(results); CheckSearchResultDuplicate(results, group_size);
DeleteSearchResult(c_search_res_1); DeleteSearchResult(c_search_res_1);
DeleteSearchResult(c_search_res_2); DeleteSearchResult(c_search_res_2);
DeleteSearchResultDataBlobs(cSearchResultData); DeleteSearchResultDataBlobs(cSearchResultData);
@ -664,7 +674,8 @@ TEST(GroupBY, GrowingRawData) {
int64_t rows_per_batch = 512; int64_t rows_per_batch = 512;
int n_batch = 3; int n_batch = 3;
for (int i = 0; i < n_batch; i++) { for (int i = 0; i < n_batch; i++) {
auto data_set = DataGen(schema, rows_per_batch); auto data_set =
DataGen(schema, rows_per_batch, 42, 0, 8, 10, false, false);
auto offset = segment_growing_impl->PreInsert(rows_per_batch); auto offset = segment_growing_impl->PreInsert(rows_per_batch);
segment_growing_impl->Insert(offset, segment_growing_impl->Insert(offset,
rows_per_batch, rows_per_batch,
@ -674,6 +685,9 @@ TEST(GroupBY, GrowingRawData) {
} }
//2. Search group by //2. Search group by
auto num_queries = 10;
auto topK = 100;
int group_size = 1;
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 102 field_id: 102
query_info: < query_info: <
@ -681,6 +695,7 @@ TEST(GroupBY, GrowingRawData) {
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 101 group_by_field_id: 101
group_size: 1
> >
placeholder_tag: "$0" placeholder_tag: "$0"
@ -688,8 +703,6 @@ TEST(GroupBY, GrowingRawData) {
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 10;
auto topK = 100;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
@ -698,25 +711,25 @@ TEST(GroupBY, GrowingRawData) {
CheckGroupBySearchResult(*search_result, topK, num_queries, true); CheckGroupBySearchResult(*search_result, topK, num_queries, true);
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
int size = group_by_values.size();
ASSERT_EQ(size, 640);
//as the number of data is 512 and repeated count is 8, the group number is 64 for every query
//and the total group number should be 640
int expected_group_count = 64;
int idx = 0; int idx = 0;
for (int i = 0; i < num_queries; i++) { for (int i = 0; i < num_queries; i++) {
std::unordered_set<int32_t> i32_set; std::unordered_set<int32_t> i32_set;
float lastDistance = 0.0; float lastDistance = 0.0;
for (int j = 0; j < topK; j++) { for (int j = 0; j < expected_group_count; j++) {
if (std::holds_alternative<int32_t>(group_by_values[idx])) { if (std::holds_alternative<int32_t>(group_by_values[idx])) {
int32_t g_val = std::get<int32_t>(group_by_values[idx]); int32_t g_val = std::get<int32_t>(group_by_values[idx]);
ASSERT_FALSE(i32_set.count(g_val) > ASSERT_FALSE(
0); //no repetition on groupBy field i32_set.count(g_val) >
0); //as the group_size is 1, there should not be any duplication for group_by value
i32_set.insert(g_val); i32_set.insert(g_val);
auto distance = search_result->distances_.at(idx); auto distance = search_result->distances_.at(idx);
ASSERT_TRUE( ASSERT_TRUE(lastDistance <= distance);
lastDistance <=
distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[idx], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[idx], 0.0);
} }
idx++; idx++;
} }
@ -760,7 +773,8 @@ TEST(GroupBY, GrowingIndex) {
int64_t rows_per_batch = 1024; int64_t rows_per_batch = 1024;
int n_batch = 10; int n_batch = 10;
for (int i = 0; i < n_batch; i++) { for (int i = 0; i < n_batch; i++) {
auto data_set = DataGen(schema, rows_per_batch); auto data_set =
DataGen(schema, rows_per_batch, 42, 0, 8, 10, false, false);
auto offset = segment_growing_impl->PreInsert(rows_per_batch); auto offset = segment_growing_impl->PreInsert(rows_per_batch);
segment_growing_impl->Insert(offset, segment_growing_impl->Insert(offset,
rows_per_batch, rows_per_batch,
@ -770,6 +784,9 @@ TEST(GroupBY, GrowingIndex) {
} }
//2. Search group by int32 //2. Search group by int32
auto num_queries = 10;
auto topK = 100;
int group_size = 3;
const char* raw_plan = R"(vector_anns: < const char* raw_plan = R"(vector_anns: <
field_id: 102 field_id: 102
query_info: < query_info: <
@ -777,6 +794,7 @@ TEST(GroupBY, GrowingIndex) {
metric_type: "L2" metric_type: "L2"
search_params: "{\"ef\": 10}" search_params: "{\"ef\": 10}"
group_by_field_id: 101 group_by_field_id: 101
group_size: 3
> >
placeholder_tag: "$0" placeholder_tag: "$0"
@ -784,8 +802,6 @@ TEST(GroupBY, GrowingIndex) {
auto plan_str = translate_text_plan_to_binary_plan(raw_plan); auto plan_str = translate_text_plan_to_binary_plan(raw_plan);
auto plan = auto plan =
CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size()); CreateSearchPlanByExpr(*schema, plan_str.data(), plan_str.size());
auto num_queries = 10;
auto topK = 100;
auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed); auto ph_group_raw = CreatePlaceholderGroup(num_queries, dim, seed);
auto ph_group = auto ph_group =
ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString()); ParsePlaceholderGroup(plan.get(), ph_group_raw.SerializeAsString());
@ -794,27 +810,29 @@ TEST(GroupBY, GrowingIndex) {
CheckGroupBySearchResult(*search_result, topK, num_queries, true); CheckGroupBySearchResult(*search_result, topK, num_queries, true);
auto& group_by_values = search_result->group_by_values_.value(); auto& group_by_values = search_result->group_by_values_.value();
auto size = group_by_values.size();
int expected_group_count = 100;
ASSERT_EQ(size, expected_group_count * group_size * num_queries);
int idx = 0; int idx = 0;
for (int i = 0; i < num_queries; i++) { for (int i = 0; i < num_queries; i++) {
std::unordered_set<int32_t> i32_set; std::unordered_map<int32_t, int> i32_map;
float lastDistance = 0.0; float lastDistance = 0.0;
for (int j = 0; j < topK; j++) { for (int j = 0; j < expected_group_count * group_size; j++) {
if (std::holds_alternative<int32_t>(group_by_values[idx])) { if (std::holds_alternative<int32_t>(group_by_values[idx])) {
int32_t g_val = std::get<int32_t>(group_by_values[idx]); int32_t g_val = std::get<int32_t>(group_by_values[idx]);
ASSERT_FALSE(i32_set.count(g_val) > i32_map[g_val] += 1;
0); //no repetition on groupBy field ASSERT_TRUE(i32_map[g_val] <= group_size);
i32_set.insert(g_val);
auto distance = search_result->distances_.at(idx); auto distance = search_result->distances_.at(idx);
ASSERT_TRUE( ASSERT_TRUE(
lastDistance <= lastDistance <=
distance); //distance should be decreased as metrics_type is L2 distance); //distance should be decreased as metrics_type is L2
lastDistance = distance; lastDistance = distance;
} else {
//check padding
ASSERT_EQ(search_result->seg_offsets_[idx], INVALID_SEG_OFFSET);
ASSERT_EQ(search_result->distances_[idx], 0.0);
} }
idx++; idx++;
} }
ASSERT_EQ(i32_map.size(), expected_group_count);
for (const auto& map_pair : i32_map) {
ASSERT_EQ(group_size, map_pair.second);
}
} }
} }

View File

@ -28,7 +28,7 @@
#include "knowhere/comp/index_param.h" #include "knowhere/comp/index_param.h"
#include "nlohmann/json.hpp" #include "nlohmann/json.hpp"
#include "query/SearchBruteForce.h" #include "query/SearchBruteForce.h"
#include "segcore/Reduce.h" #include "segcore/reduce/Reduce.h"
#include "index/IndexFactory.h" #include "index/IndexFactory.h"
#include "common/QueryResult.h" #include "common/QueryResult.h"
#include "segcore/Types.h" #include "segcore/Types.h"

View File

@ -212,3 +212,13 @@ TEST(Util, get_common_prefix) {
common_prefix = milvus::GetCommonPrefix(str1, str2); common_prefix = milvus::GetCommonPrefix(str1, str2);
EXPECT_STREQ(common_prefix.c_str(), ""); EXPECT_STREQ(common_prefix.c_str(), "");
} }
// Exercises milvus::query::dis_closer with the "L2" and "IP" metric names.
// Expectations encoded below: for "L2" the call is expected to return true
// only when the first value is the smaller one; for "IP" only when the first
// value is the larger one; two equal values are expected to yield false under
// both metric names.
TEST(Util, dis_closer) {
    using milvus::query::dis_closer;

    // Same literal values the checks have always used, named for readability.
    constexpr double lesser = 0.1;
    constexpr double greater = 0.2;

    // "L2" expectations
    EXPECT_TRUE(dis_closer(lesser, greater, "L2"));
    EXPECT_FALSE(dis_closer(greater, lesser, "L2"));
    EXPECT_FALSE(dis_closer(lesser, lesser, "L2"));

    // "IP" expectations
    EXPECT_TRUE(dis_closer(greater, lesser, "IP"));
    EXPECT_FALSE(dis_closer(lesser, greater, "IP"));
    EXPECT_FALSE(dis_closer(lesser, lesser, "IP"));
}

View File

@ -236,7 +236,8 @@ struct GeneratedData {
uint64_t ts_offset, uint64_t ts_offset,
int repeat_count, int repeat_count,
int array_len, int array_len,
bool random_pk); bool random_pk,
bool random_val);
friend GeneratedData friend GeneratedData
DataGenForJsonArray(SchemaPtr schema, DataGenForJsonArray(SchemaPtr schema,
int64_t N, int64_t N,
@ -307,7 +308,8 @@ inline GeneratedData DataGen(SchemaPtr schema,
uint64_t ts_offset = 0, uint64_t ts_offset = 0,
int repeat_count = 1, int repeat_count = 1,
int array_len = 10, int array_len = 10,
bool random_pk = false) { bool random_pk = false,
bool random_val = true) {
using std::vector; using std::vector;
std::default_random_engine random(seed); std::default_random_engine random(seed);
std::normal_distribution<> distr(0, 1); std::normal_distribution<> distr(0, 1);
@ -414,24 +416,39 @@ inline GeneratedData DataGen(SchemaPtr schema,
} }
case DataType::INT32: { case DataType::INT32: {
vector<int> data(N); vector<int> data(N);
for (auto& x : data) { for (int i = 0; i < N; i++) {
x = random() % (2 * N); int x = 0;
if (random_val)
x = random() % (2 * N);
else
x = i / repeat_count;
data[i] = x;
} }
insert_cols(data, N, field_meta); insert_cols(data, N, field_meta);
break; break;
} }
case DataType::INT16: { case DataType::INT16: {
vector<int16_t> data(N); vector<int16_t> data(N);
for (auto& x : data) { for (int i = 0; i < N; i++) {
x = random() % (2 * N); int16_t x = 0;
if (random_val)
x = random() % (2 * N);
else
x = i / repeat_count;
data[i] = x;
} }
insert_cols(data, N, field_meta); insert_cols(data, N, field_meta);
break; break;
} }
case DataType::INT8: { case DataType::INT8: {
vector<int8_t> data(N); vector<int8_t> data(N);
for (auto& x : data) { for (int i = 0; i < N; i++) {
x = random() % (2 * N); int8_t x = 0;
if (random_val)
x = random() % (2 * N);
else
x = i / repeat_count;
data[i] = x;
} }
insert_cols(data, N, field_meta); insert_cols(data, N, field_meta);
break; break;
@ -1175,7 +1192,6 @@ translate_text_plan_to_binary_plan(const char* text_plan) {
auto ok = auto ok =
google::protobuf::TextFormat::ParseFromString(text_plan, &plan_node); google::protobuf::TextFormat::ParseFromString(text_plan, &plan_node);
AssertInfo(ok, "Failed to parse"); AssertInfo(ok, "Failed to parse");
std::string binary_plan; std::string binary_plan;
plan_node.SerializeToString(&binary_plan); plan_node.SerializeToString(&binary_plan);

View File

@ -25,7 +25,7 @@
#include "common/type_c.h" #include "common/type_c.h"
#include "pb/plan.pb.h" #include "pb/plan.pb.h"
#include "segcore/Collection.h" #include "segcore/Collection.h"
#include "segcore/Reduce.h" #include "segcore/reduce/Reduce.h"
#include "segcore/reduce_c.h" #include "segcore/reduce_c.h"
#include "segcore/segment_c.h" #include "segcore/segment_c.h"
#include "futures/Future.h" #include "futures/Future.h"
@ -86,14 +86,14 @@ generate_query_data(int nq) {
return blob; return blob;
} }
void void
CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) { CheckSearchResultDuplicate(const std::vector<CSearchResult>& results,
int group_size = 1) {
auto nq = ((SearchResult*)results[0])->total_nq_; auto nq = ((SearchResult*)results[0])->total_nq_;
std::unordered_set<PkType> pk_set; std::unordered_set<PkType> pk_set;
std::unordered_set<GroupByValueType> group_by_val_set; std::unordered_map<GroupByValueType, int> group_by_map;
for (int qi = 0; qi < nq; qi++) { for (int qi = 0; qi < nq; qi++) {
pk_set.clear(); pk_set.clear();
group_by_val_set.clear(); group_by_map.clear();
for (size_t i = 0; i < results.size(); i++) { for (size_t i = 0; i < results.size(); i++) {
auto search_result = (SearchResult*)results[i]; auto search_result = (SearchResult*)results[i];
ASSERT_EQ(nq, search_result->total_nq_); ASSERT_EQ(nq, search_result->total_nq_);
@ -108,8 +108,8 @@ CheckSearchResultDuplicate(const std::vector<CSearchResult>& results) {
search_result->group_by_values_.value().size() > ki) { search_result->group_by_values_.value().size() > ki) {
auto group_by_val = auto group_by_val =
search_result->group_by_values_.value()[ki]; search_result->group_by_values_.value()[ki];
ASSERT_TRUE(group_by_val_set.count(group_by_val) == 0); group_by_map[group_by_val] += 1;
group_by_val_set.insert(group_by_val); ASSERT_TRUE(group_by_map[group_by_val] <= group_size);
} }
} }
} }

View File

@ -61,6 +61,7 @@ message QueryInfo {
int64 round_decimal = 5; int64 round_decimal = 5;
int64 group_by_field_id = 6; int64 group_by_field_id = 6;
bool materialized_view_involved = 7; bool materialized_view_involved = 7;
int64 group_size = 8;
} }
message ColumnInfo { message ColumnInfo {