// Licensed to the LF AI & Data foundation under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "ExprImpl.h" #include "Parser.h" #include "Plan.h" #include "generated/ExtractInfoPlanNodeVisitor.h" #include "generated/VerifyPlanNodeVisitor.h" namespace milvus::query { template static ExprPtr ConstructTree(Merger merger, std::vector item_list) { if (item_list.size() == 0) { return nullptr; } if (item_list.size() == 1) { return std::move(item_list[0]); } // Note: use deque to construct a binary tree // Op // / \ // Op Op // | \ | \ // A B C D std::deque binary_queue; for (auto& item : item_list) { Assert(item != nullptr); binary_queue.push_back(std::move(item)); } while (binary_queue.size() > 1) { auto left = std::move(binary_queue.front()); binary_queue.pop_front(); auto right = std::move(binary_queue.front()); binary_queue.pop_front(); binary_queue.push_back(merger(std::move(left), std::move(right))); } Assert(binary_queue.size() == 1); return std::move(binary_queue.front()); } ExprPtr Parser::ParseCompareNode(const Json& out_body) { Assert(out_body.is_object()); Assert(out_body.size() == 1); auto out_iter = out_body.begin(); auto op_name = boost::algorithm::to_lower_copy(std::string(out_iter.key())); AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found"); auto body = out_iter.value(); Assert(body.is_array()); Assert(body.size() == 2); auto expr = std::make_unique(); expr->op_type_ = mapping_.at(op_name); auto& item0 = body[0]; Assert(item0.is_string()); auto left_field_name = FieldName(item0.get()); expr->left_data_type_ = schema[left_field_name].get_data_type(); expr->left_field_id_ = schema.get_field_id(left_field_name); auto& item1 = body[1]; Assert(item1.is_string()); auto right_field_name = FieldName(item1.get()); expr->right_data_type_ = schema[right_field_name].get_data_type(); expr->right_field_id_ = schema.get_field_id(right_field_name); return expr; } ExprPtr Parser::ParseRangeNode(const Json& out_body) { Assert(out_body.is_object()); Assert(out_body.size() == 1); auto out_iter = out_body.begin(); auto field_name = FieldName(out_iter.key()); auto body = out_iter.value(); auto data_type = schema[field_name].get_data_type(); Assert(!datatype_is_vector(data_type)); switch (data_type) { case DataType::BOOL: return ParseRangeNodeImpl(field_name, body); case DataType::INT8: return ParseRangeNodeImpl(field_name, body); case DataType::INT16: return ParseRangeNodeImpl(field_name, body); case DataType::INT32: return ParseRangeNodeImpl(field_name, body); case DataType::INT64: return ParseRangeNodeImpl(field_name, body); case DataType::FLOAT: return ParseRangeNodeImpl(field_name, body); case DataType::DOUBLE: return ParseRangeNodeImpl(field_name, body); default: PanicInfo("unsupported"); } } std::unique_ptr Parser::CreatePlanImpl(const Json& dsl) { auto bool_dsl = dsl.at("bool"); auto predicate = ParseAnyNode(bool_dsl); Assert(vector_node_opt_.has_value()); auto vec_node = std::move(vector_node_opt_).value(); if (predicate != nullptr) { vec_node->predicate_ = std::move(predicate); } VerifyPlanNodeVisitor verifier; vec_node->accept(verifier); ExtractedPlanInfo plan_info(schema.size()); ExtractInfoPlanNodeVisitor extractor(plan_info); vec_node->accept(extractor); auto plan = std::make_unique(schema); plan->tag2field_ = std::move(tag2field_); plan->plan_node_ = std::move(vec_node); plan->extra_info_opt_ = std::move(plan_info); return plan; } ExprPtr Parser::ParseTermNode(const Json& out_body) { Assert(out_body.size() == 1); auto out_iter = out_body.begin(); auto field_name = FieldName(out_iter.key()); auto body = out_iter.value(); auto data_type = schema[field_name].get_data_type(); Assert(!datatype_is_vector(data_type)); switch (data_type) { case DataType::BOOL: { return ParseTermNodeImpl(field_name, body); } case DataType::INT8: { return ParseTermNodeImpl(field_name, body); } case DataType::INT16: { return ParseTermNodeImpl(field_name, body); } case DataType::INT32: { return ParseTermNodeImpl(field_name, body); } case DataType::INT64: { return ParseTermNodeImpl(field_name, body); } case DataType::FLOAT: { return ParseTermNodeImpl(field_name, body); } case DataType::DOUBLE: { return ParseTermNodeImpl(field_name, body); } default: { PanicInfo("unsupported data_type"); } } } std::unique_ptr Parser::ParseVecNode(const Json& out_body) { Assert(out_body.is_object()); Assert(out_body.size() == 1); auto iter = out_body.begin(); auto field_name = FieldName(iter.key()); auto& vec_info = iter.value(); Assert(vec_info.is_object()); auto topk = vec_info["topk"]; AssertInfo(topk > 0, "topk must greater than 0"); AssertInfo(topk < 16384, "topk is too large"); auto field_id = schema.get_field_id(field_name); auto vec_node = [&]() -> std::unique_ptr { auto& field_meta = schema.operator[](field_name); auto data_type = field_meta.get_data_type(); if (data_type == DataType::VECTOR_FLOAT) { return std::make_unique(); } else { return std::make_unique(); } }(); vec_node->search_info_.topk_ = topk; vec_node->search_info_.metric_type_ = vec_info.at("metric_type"); vec_node->search_info_.search_params_ = vec_info.at("params"); vec_node->search_info_.field_id_ = field_id; vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal"); vec_node->placeholder_tag_ = vec_info.at("query"); auto tag = vec_node->placeholder_tag_; AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag"); tag2field_.emplace(tag, field_id); return vec_node; } template ExprPtr Parser::ParseTermNodeImpl(const FieldName& field_name, const Json& body) { Assert(body.is_object()); auto values = body["values"]; std::vector terms(values.size()); for (int i = 0; i < values.size(); i++) { auto value = values[i]; if constexpr (std::is_same_v) { Assert(value.is_boolean()); } else if constexpr (std::is_integral_v) { Assert(value.is_number_integer()); } else if constexpr (std::is_floating_point_v) { Assert(value.is_number()); } else { static_assert(always_false, "unsupported type"); } terms[i] = value; } std::sort(terms.begin(), terms.end()); return std::make_unique>(schema.get_field_id(field_name), schema[field_name].get_data_type(), terms); } template ExprPtr Parser::ParseRangeNodeImpl(const FieldName& field_name, const Json& body) { Assert(body.is_object()); if (body.size() == 1) { auto item = body.begin(); auto op_name = boost::algorithm::to_lower_copy(std::string(item.key())); AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found"); // This is an expression with an arithmetic operation if (item.value().is_object()) { /* // This is the expected DSL expression { range: { field_name: { op: { arith_op: { right_operand: operand, value: value }, } } } } EXAMPLE: { range: { field_name: { "EQ": { "ADD": { right_operand: 10, value: 25 }, } } } } */ auto arith = item.value(); auto arith_body = arith.begin(); auto arith_op_name = boost::algorithm::to_lower_copy(std::string(arith_body.key())); AssertInfo(arith_op_mapping_.count(arith_op_name), "arith op(" + arith_op_name + ") not found"); auto& arith_op_body = arith_body.value(); Assert(arith_op_body.is_object()); auto right_operand = arith_op_body["right_operand"]; auto value = arith_op_body["value"]; if constexpr (std::is_same_v) { throw std::runtime_error("bool type is not supported"); } else if constexpr (std::is_integral_v) { Assert(right_operand.is_number_integer()); Assert(value.is_number_integer()); } else if constexpr (std::is_floating_point_v) { Assert(right_operand.is_number()); Assert(value.is_number()); } else { static_assert(always_false, "unsupported type"); } return std::make_unique>( schema.get_field_id(field_name), schema[field_name].get_data_type(), arith_op_mapping_.at(arith_op_name), right_operand, mapping_.at(op_name), value); } if constexpr (std::is_same_v) { Assert(item.value().is_boolean()); } else if constexpr (std::is_integral_v) { Assert(item.value().is_number_integer()); } else if constexpr (std::is_floating_point_v) { Assert(item.value().is_number()); } else { static_assert(always_false, "unsupported type"); } return std::make_unique>( schema.get_field_id(field_name), schema[field_name].get_data_type(), mapping_.at(op_name), item.value()); } else if (body.size() == 2) { bool has_lower_value = false; bool has_upper_value = false; bool lower_inclusive = false; bool upper_inclusive = false; T lower_value; T upper_value; for (auto& item : body.items()) { auto op_name = boost::algorithm::to_lower_copy(std::string(item.key())); AssertInfo(mapping_.count(op_name), "op(" + op_name + ") not found"); if constexpr (std::is_same_v) { Assert(item.value().is_boolean()); } else if constexpr (std::is_integral_v) { Assert(item.value().is_number_integer()); } else if constexpr (std::is_floating_point_v) { Assert(item.value().is_number()); } else { static_assert(always_false, "unsupported type"); } auto op = mapping_.at(op_name); switch (op) { case OpType::GreaterEqual: lower_inclusive = true; case OpType::GreaterThan: lower_value = item.value(); has_lower_value = true; break; case OpType::LessEqual: upper_inclusive = true; case OpType::LessThan: upper_value = item.value(); has_upper_value = true; break; default: PanicInfo("unsupported operator in binary-range node"); } } AssertInfo(has_lower_value && has_upper_value, "illegal binary-range node"); return std::make_unique>( schema.get_field_id(field_name), schema[field_name].get_data_type(), lower_inclusive, upper_inclusive, lower_value, upper_value); } else { PanicInfo("illegal range node, too more or too few ops"); } } std::vector Parser::ParseItemList(const Json& body) { std::vector results; if (body.is_object()) { // only one item; auto new_expr = ParseAnyNode(body); results.emplace_back(std::move(new_expr)); } else { // item array Assert(body.is_array()); for (auto& item : body) { auto new_expr = ParseAnyNode(item); results.emplace_back(std::move(new_expr)); } } auto old_size = results.size(); auto new_end = std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) { return x == nullptr; }); results.resize(new_end - results.begin()); return results; } ExprPtr Parser::ParseAnyNode(const Json& out_body) { Assert(out_body.is_object()); Assert(out_body.size() == 1); auto out_iter = out_body.begin(); auto key = out_iter.key(); auto body = out_iter.value(); if (key == "must") { return ParseMustNode(body); } else if (key == "should") { return ParseShouldNode(body); } else if (key == "must_not") { return ParseMustNotNode(body); } else if (key == "range") { return ParseRangeNode(body); } else if (key == "term") { return ParseTermNode(body); } else if (key == "compare") { return ParseCompareNode(body); } else if (key == "vector") { auto vec_node = ParseVecNode(body); Assert(!vector_node_opt_.has_value()); vector_node_opt_ = std::move(vec_node); return nullptr; } else { PanicInfo("unsupported key: " + key); } } ExprPtr Parser::ParseMustNode(const Json& body) { auto item_list = ParseItemList(body); auto merger = [](ExprPtr left, ExprPtr right) { using OpType = LogicalBinaryExpr::OpType; return std::make_unique( OpType::LogicalAnd, left, right); }; return ConstructTree(merger, std::move(item_list)); } ExprPtr Parser::ParseShouldNode(const Json& body) { auto item_list = ParseItemList(body); Assert(item_list.size() >= 1); auto merger = [](ExprPtr left, ExprPtr right) { using OpType = LogicalBinaryExpr::OpType; return std::make_unique( OpType::LogicalOr, left, right); }; return ConstructTree(merger, std::move(item_list)); } ExprPtr Parser::ParseMustNotNode(const Json& body) { auto item_list = ParseItemList(body); Assert(item_list.size() >= 1); auto merger = [](ExprPtr left, ExprPtr right) { using OpType = LogicalBinaryExpr::OpType; return std::make_unique( OpType::LogicalAnd, left, right); }; auto subtree = ConstructTree(merger, std::move(item_list)); using OpType = LogicalUnaryExpr::OpType; return std::make_unique(OpType::LogicalNot, subtree); } } // namespace milvus::query