// Copyright (C) 2019-2020 Zilliz. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software distributed under the License // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License #include "query/Plan.h" #include "query/PlanImpl.h" #include "query/PlanNode.h" #include "query/ExprImpl.h" #include "utils/Json.h" #include "utils/EasyAssert.h" #include "pb/service_msg.pb.h" #include #include #include namespace milvus::query { static std::unique_ptr ParseVecNode(Plan* plan, const Json& out_body) { Assert(out_body.is_object()); // TODO add binary info Assert(out_body.size() == 1); auto iter = out_body.begin(); std::string field_name = iter.key(); auto& vec_info = iter.value(); Assert(vec_info.is_object()); auto topK = vec_info["topk"]; AssertInfo(topK > 0, "topK must greater than 0"); AssertInfo(topK < 16384, "topK is too large"); auto field_meta = plan->schema_.operator[](field_name); auto vec_node = [&]() -> std::unique_ptr { auto data_type = field_meta.get_data_type(); if (data_type == DataType::VECTOR_FLOAT) { return std::make_unique(); } else { return std::make_unique(); } }(); vec_node->query_info_.topK_ = topK; vec_node->query_info_.metric_type_ = vec_info.at("metric_type"); vec_node->query_info_.search_params_ = vec_info.at("params"); vec_node->query_info_.field_id_ = field_name; vec_node->placeholder_tag_ = vec_info.at("query"); auto tag = vec_node->placeholder_tag_; AssertInfo(!plan->tag2field_.count(tag), "duplicated placeholder tag"); plan->tag2field_.emplace(tag, field_name); return vec_node; } /// initialize RangeExpr::mapping_ const std::map RangeExpr::mapping_ = { {"lt", OpType::LessThan}, {"le", OpType::LessEqual}, {"lte", OpType::LessEqual}, {"gt", OpType::GreaterThan}, {"ge", OpType::GreaterEqual}, {"gte", OpType::GreaterEqual}, {"eq", OpType::Equal}, {"ne", OpType::NotEqual}, }; static inline std::string to_lower(const std::string& raw) { auto data = raw; std::transform(data.begin(), data.end(), data.begin(), [](unsigned char c) { return std::tolower(c); }); return data; } template constexpr std::false_type always_false{}; template std::unique_ptr ParseRangeNodeImpl(const Schema& schema, const std::string& field_name, const Json& body) { auto expr = std::make_unique>(); auto data_type = schema[field_name].get_data_type(); expr->data_type_ = data_type; expr->field_id_ = field_name; Assert(body.is_object()); for (auto& item : body.items()) { auto op_name = to_lower(item.key()); AssertInfo(RangeExpr::mapping_.count(op_name), "op(" + op_name + ") not found"); auto op = RangeExpr::mapping_.at(op_name); if constexpr (std::is_integral_v) { Assert(item.value().is_number_integer()); } else if constexpr (std::is_floating_point_v) { Assert(item.value().is_number()); } else { static_assert(always_false, "unsupported type"); } T value = item.value(); expr->conditions_.emplace_back(op, value); } return expr; } std::unique_ptr ParseRangeNode(const Schema& schema, const Json& out_body) { Assert(out_body.size() == 1); auto out_iter = out_body.begin(); auto field_name = out_iter.key(); auto body = out_iter.value(); auto data_type = schema[field_name].get_data_type(); Assert(!field_is_vector(data_type)); switch (data_type) { case DataType::BOOL: { PanicInfo("bool is not supported in Range node"); // return ParseRangeNodeImpl(schema, field_name, body); } case DataType::INT8: return ParseRangeNodeImpl(schema, field_name, body); case DataType::INT16: return ParseRangeNodeImpl(schema, field_name, body); case DataType::INT32: return ParseRangeNodeImpl(schema, field_name, body); case DataType::INT64: return ParseRangeNodeImpl(schema, field_name, body); case DataType::FLOAT: return ParseRangeNodeImpl(schema, field_name, body); case DataType::DOUBLE: return ParseRangeNodeImpl(schema, field_name, body); default: PanicInfo("unsupported"); } } static std::unique_ptr CreatePlanImplNaive(const Schema& schema, const std::string& dsl_str) { auto plan = std::make_unique(schema); auto dsl = nlohmann::json::parse(dsl_str); nlohmann::json vec_pack; std::optional> predicate; auto& bool_dsl = dsl.at("bool"); if (bool_dsl.contains("must")) { auto& packs = bool_dsl.at("must"); Assert(packs.is_array()); for (auto& pack : packs) { if (pack.contains("vector")) { auto& out_body = pack.at("vector"); plan->plan_node_ = ParseVecNode(plan.get(), out_body); } else if (pack.contains("range")) { AssertInfo(!predicate, "unsupported complex DSL"); auto& out_body = pack.at("range"); predicate = ParseRangeNode(schema, out_body); } else { PanicInfo("unsupported node"); } } AssertInfo(plan->plan_node_, "vector node not found"); } else if (bool_dsl.contains("vector")) { auto& out_body = bool_dsl.at("vector"); plan->plan_node_ = ParseVecNode(plan.get(), out_body); Assert(plan->plan_node_); } else { PanicInfo("Unsupported DSL"); } plan->plan_node_->predicate_ = std::move(predicate); // TODO: target_entry parser // if schema autoid is true, // prepend target_entries_ with row_id // else // with primary_key // return plan; } std::unique_ptr CreatePlan(const Schema& schema, const std::string& dsl_str) { auto plan = CreatePlanImplNaive(schema, dsl_str); return plan; } std::unique_ptr ParsePlaceholderGroup(const Plan* plan, const std::string& blob) { namespace ser = milvus::proto::service; auto result = std::make_unique(); ser::PlaceholderGroup ph_group; auto ok = ph_group.ParseFromString(blob); Assert(ok); for (auto& info : ph_group.placeholders()) { Placeholder element; element.tag_ = info.tag(); Assert(plan->tag2field_.count(element.tag_)); auto field_id = plan->tag2field_.at(element.tag_); auto& field_meta = plan->schema_[field_id]; element.num_of_queries_ = info.values_size(); AssertInfo(element.num_of_queries_, "must have queries"); Assert(element.num_of_queries_ > 0); element.line_sizeof_ = info.values().Get(0).size(); Assert(field_meta.get_sizeof() == element.line_sizeof_); auto& target = element.blob_; target.reserve(element.line_sizeof_ * element.num_of_queries_); for (auto& line : info.values()) { Assert(element.line_sizeof_ == line.size()); target.insert(target.end(), line.begin(), line.end()); } result->emplace_back(std::move(element)); } return result; } int64_t GetTopK(const Plan* plan) { return plan->plan_node_->query_info_.topK_; } int64_t GetNumOfQueries(const PlaceholderGroup* group) { return group->at(0).num_of_queries_; } } // namespace milvus::query