// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "GroupByOperator.h" #include "common/Consts.h" #include "segcore/SegmentSealedImpl.h" #include "Utils.h" namespace milvus { namespace query { void GroupBy(const std::vector>& iterators, const SearchInfo& search_info, std::vector& group_by_values, const segcore::SegmentInternalInterface& segment, std::vector& seg_offsets, std::vector& distances) { //1. get search meta FieldId group_by_field_id = search_info.group_by_field_id_.value(); auto data_type = segment.GetFieldDataType(group_by_field_id); switch (data_type) { case DataType::INT8: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, *dataGetter, group_by_values, seg_offsets, distances, search_info.metric_type_); break; } case DataType::INT16: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, *dataGetter, group_by_values, seg_offsets, distances, search_info.metric_type_); break; } case DataType::INT32: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, *dataGetter, group_by_values, seg_offsets, distances, search_info.metric_type_); break; } case DataType::INT64: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, *dataGetter, group_by_values, seg_offsets, distances, search_info.metric_type_); break; } case DataType::BOOL: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, *dataGetter, group_by_values, seg_offsets, distances, search_info.metric_type_); break; } case DataType::VARCHAR: { auto dataGetter = GetDataGetter(segment, group_by_field_id); GroupIteratorsByType(iterators, search_info.topk_, *dataGetter, group_by_values, seg_offsets, distances, search_info.metric_type_); break; } default: { PanicInfo( Unsupported, fmt::format("unsupported data type {} for group by operator", data_type)); } } } template void GroupIteratorsByType( const std::vector>& iterators, int64_t topK, const DataGetter& data_getter, std::vector& group_by_values, std::vector& seg_offsets, std::vector& distances, const knowhere::MetricType& metrics_type) { for (auto& iterator : iterators) { GroupIteratorResult(iterator, topK, data_getter, group_by_values, seg_offsets, distances, metrics_type); } } template void GroupIteratorResult(const std::shared_ptr& iterator, int64_t topK, const DataGetter& data_getter, std::vector& group_by_values, std::vector& offsets, std::vector& distances, const knowhere::MetricType& metrics_type) { //1. std::unordered_map> groupMap; //2. do iteration until fill the whole map or run out of all data //note it may enumerate all data inside a segment and can block following //query and search possibly auto dis_closer = [&](float l, float r) { if (PositivelyRelated(metrics_type)) return l > r; return l < r; }; while (iterator->HasNext() && groupMap.size() < topK) { auto offset_dis_pair = iterator->Next(); AssertInfo( offset_dis_pair.has_value(), "Wrong state! iterator cannot return valid result whereas it still" "tells hasNext, terminate groupBy operation"); auto offset = offset_dis_pair.value().first; auto dis = offset_dis_pair.value().second; T row_data = data_getter.Get(offset); auto it = groupMap.find(row_data); if (it == groupMap.end()) { groupMap.emplace(row_data, std::make_pair(offset, dis)); } else if (dis_closer(dis, it->second.second)) { it->second = {offset, dis}; } } //3. sorted based on distances and metrics std::vector>> sortedGroupVals( groupMap.begin(), groupMap.end()); auto customComparator = [&](const auto& lhs, const auto& rhs) { return dis_closer(lhs.second.second, rhs.second.second); }; std::sort(sortedGroupVals.begin(), sortedGroupVals.end(), customComparator); //4. save groupBy results group_by_values.reserve(sortedGroupVals.size()); offsets.reserve(sortedGroupVals.size()); distances.reserve(sortedGroupVals.size()); for (auto iter = sortedGroupVals.cbegin(); iter != sortedGroupVals.cend(); iter++) { group_by_values.emplace_back(iter->first); offsets.push_back(iter->second.first); distances.push_back(iter->second.second); } //5. padding topK results, extra memory consumed will be removed when reducing for (std::size_t idx = groupMap.size(); idx < topK; idx++) { offsets.push_back(INVALID_SEG_OFFSET); distances.push_back(0.0); group_by_values.emplace_back(std::monostate{}); } } } // namespace query } // namespace milvus