// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "CompareExpr.h" #include "query/Relational.h" namespace milvus { namespace exec { bool PhyCompareFilterExpr::IsStringExpr() { return expr_->left_data_type_ == DataType::VARCHAR || expr_->right_data_type_ == DataType::VARCHAR; } int64_t PhyCompareFilterExpr::GetNextBatchSize() { auto current_rows = segment_->type() == SegmentType::Growing ? current_chunk_id_ * size_per_chunk_ + current_chunk_pos_ : current_chunk_pos_; return current_rows + batch_size_ >= num_rows_ ? num_rows_ - current_rows : batch_size_; } template ChunkDataAccessor PhyCompareFilterExpr::GetChunkData(FieldId field_id, int chunk_id, int data_barrier) { if (chunk_id >= data_barrier) { auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { return [&indexing](int i) -> const number { return indexing.Reverse_Lookup(i); }; } } auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); return [chunk_data](int i) -> const number { return chunk_data[i]; }; } template <> ChunkDataAccessor PhyCompareFilterExpr::GetChunkData(FieldId field_id, int chunk_id, int data_barrier) { if (chunk_id >= data_barrier) { auto& indexing = segment_->chunk_scalar_index(field_id, chunk_id); if (indexing.HasRawData()) { return [&indexing](int i) -> const std::string { return indexing.Reverse_Lookup(i); }; } } if (segment_->type() == SegmentType::Growing) { auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); return [chunk_data](int i) -> const number { return chunk_data[i]; }; } else { auto chunk_data = segment_->chunk_data(field_id, chunk_id).data(); return [chunk_data](int i) -> const number { return std::string(chunk_data[i]); }; } } ChunkDataAccessor PhyCompareFilterExpr::GetChunkData(DataType data_type, FieldId field_id, int chunk_id, int data_barrier) { switch (data_type) { case DataType::BOOL: return GetChunkData(field_id, chunk_id, data_barrier); case DataType::INT8: return GetChunkData(field_id, chunk_id, data_barrier); case DataType::INT16: return GetChunkData(field_id, chunk_id, data_barrier); case DataType::INT32: return GetChunkData(field_id, chunk_id, data_barrier); case DataType::INT64: return GetChunkData(field_id, chunk_id, data_barrier); case DataType::FLOAT: return GetChunkData(field_id, chunk_id, data_barrier); case DataType::DOUBLE: return GetChunkData(field_id, chunk_id, data_barrier); case DataType::VARCHAR: { return GetChunkData(field_id, chunk_id, data_barrier); } default: PanicInfo(DataTypeInvalid, "unsupported data type: {}", data_type); } } template VectorPtr PhyCompareFilterExpr::ExecCompareExprDispatcher(OpType op) { auto real_batch_size = GetNextBatchSize(); if (real_batch_size == 0) { return nullptr; } auto res_vec = std::make_shared(DataType::BOOL, real_batch_size); bool* res = (bool*)res_vec->GetRawData(); auto left_data_barrier = segment_->num_chunk_data(expr_->left_field_id_); auto right_data_barrier = segment_->num_chunk_data(expr_->right_field_id_); int64_t processed_rows = 0; for (int64_t chunk_id = current_chunk_id_; chunk_id < num_chunk_; ++chunk_id) { auto chunk_size = chunk_id == num_chunk_ - 1 ? num_rows_ - chunk_id * size_per_chunk_ : size_per_chunk_; auto left = GetChunkData(expr_->left_data_type_, expr_->left_field_id_, chunk_id, left_data_barrier); auto right = GetChunkData(expr_->right_data_type_, expr_->right_field_id_, chunk_id, right_data_barrier); for (int i = chunk_id == current_chunk_id_ ? current_chunk_pos_ : 0; i < chunk_size; ++i) { res[processed_rows++] = boost::apply_visitor( milvus::query::Relational{}, left(i), right(i)); if (processed_rows >= batch_size_) { current_chunk_id_ = chunk_id; current_chunk_pos_ = i + 1; return res_vec; } } } return res_vec; } void PhyCompareFilterExpr::Eval(EvalCtx& context, VectorPtr& result) { // For segment both fields has no index, can use SIMD to speed up. // Avoiding too much call stack that blocks SIMD. if (!is_left_indexed_ && !is_right_indexed_ && !IsStringExpr()) { result = ExecCompareExprDispatcherForBothDataSegment(); return; } result = ExecCompareExprDispatcherForHybridSegment(); } VectorPtr PhyCompareFilterExpr::ExecCompareExprDispatcherForHybridSegment() { switch (expr_->op_type_) { case OpType::Equal: { return ExecCompareExprDispatcher(std::equal_to<>{}); } case OpType::NotEqual: { return ExecCompareExprDispatcher(std::not_equal_to<>{}); } case OpType::GreaterEqual: { return ExecCompareExprDispatcher(std::greater_equal<>{}); } case OpType::GreaterThan: { return ExecCompareExprDispatcher(std::greater<>{}); } case OpType::LessEqual: { return ExecCompareExprDispatcher(std::less_equal<>{}); } case OpType::LessThan: { return ExecCompareExprDispatcher(std::less<>{}); } case OpType::PrefixMatch: { return ExecCompareExprDispatcher( milvus::query::MatchOp{}); } // case OpType::PostfixMatch: { // } default: { PanicInfo(OpTypeInvalid, "unsupported optype: {}", expr_->op_type_); } } } VectorPtr PhyCompareFilterExpr::ExecCompareExprDispatcherForBothDataSegment() { switch (expr_->left_data_type_) { case DataType::BOOL: return ExecCompareLeftType(); case DataType::INT8: return ExecCompareLeftType(); case DataType::INT16: return ExecCompareLeftType(); case DataType::INT32: return ExecCompareLeftType(); case DataType::INT64: return ExecCompareLeftType(); case DataType::FLOAT: return ExecCompareLeftType(); case DataType::DOUBLE: return ExecCompareLeftType(); default: PanicInfo( DataTypeInvalid, fmt::format("unsupported left datatype:{} of compare expr", expr_->left_data_type_)); } } template VectorPtr PhyCompareFilterExpr::ExecCompareLeftType() { switch (expr_->right_data_type_) { case DataType::BOOL: return ExecCompareRightType(); case DataType::INT8: return ExecCompareRightType(); case DataType::INT16: return ExecCompareRightType(); case DataType::INT32: return ExecCompareRightType(); case DataType::INT64: return ExecCompareRightType(); case DataType::FLOAT: return ExecCompareRightType(); case DataType::DOUBLE: return ExecCompareRightType(); default: PanicInfo( DataTypeInvalid, fmt::format("unsupported right datatype:{} of compare expr", expr_->right_data_type_)); } } template VectorPtr PhyCompareFilterExpr::ExecCompareRightType() { auto real_batch_size = GetNextBatchSize(); if (real_batch_size == 0) { return nullptr; } auto res_vec = std::make_shared(DataType::BOOL, real_batch_size); bool* res = (bool*)res_vec->GetRawData(); auto expr_type = expr_->op_type_; auto execute_sub_batch = [expr_type](const T* left, const U* right, const int size, bool* res) { switch (expr_type) { case proto::plan::GreaterThan: { CompareElementFunc func; func(left, right, size, res); break; } case proto::plan::GreaterEqual: { CompareElementFunc func; func(left, right, size, res); break; } case proto::plan::LessThan: { CompareElementFunc func; func(left, right, size, res); break; } case proto::plan::LessEqual: { CompareElementFunc func; func(left, right, size, res); break; } case proto::plan::Equal: { CompareElementFunc func; func(left, right, size, res); break; } case proto::plan::NotEqual: { CompareElementFunc func; func(left, right, size, res); break; } default: PanicInfo( OpTypeInvalid, fmt::format( "unsupported operator type for compare column expr: {}", expr_type)); } }; int64_t processed_size = ProcessBothDataChunks(execute_sub_batch, res); AssertInfo(processed_size == real_batch_size, "internal error: expr processed rows {} not equal " "expect batch size {}", processed_size, real_batch_size); return res_vec; }; } //namespace exec } // namespace milvus