mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
Using std::string instead of char * for PlaceholderGroup
Signed-off-by: FluorineDog <guilin.gou@zilliz.com>
This commit is contained in:
parent
4a3ca1632b
commit
f47fc7fef1
@ -21,6 +21,7 @@ ENDFOREACH(proto_file)
|
|||||||
add_library(milvus_proto STATIC
|
add_library(milvus_proto STATIC
|
||||||
${MILVUS_PROTO_SRCS}
|
${MILVUS_PROTO_SRCS}
|
||||||
)
|
)
|
||||||
|
message(${MILVUS_PROTO_SRCS})
|
||||||
|
|
||||||
target_link_libraries(milvus_proto
|
target_link_libraries(milvus_proto
|
||||||
libprotobuf
|
libprotobuf
|
||||||
|
|||||||
@ -5,8 +5,9 @@ set(MILVUS_QUERY_SRCS
|
|||||||
generated/Expr.cpp
|
generated/Expr.cpp
|
||||||
visitors/ShowPlanNodeVisitor.cpp
|
visitors/ShowPlanNodeVisitor.cpp
|
||||||
visitors/ExecPlanNodeVisitor.cpp
|
visitors/ExecPlanNodeVisitor.cpp
|
||||||
|
visitors/ShowExprVisitor.cpp
|
||||||
Parser.cpp
|
Parser.cpp
|
||||||
Plan.cpp
|
Plan.cpp
|
||||||
)
|
)
|
||||||
add_library(milvus_query ${MILVUS_QUERY_SRCS})
|
add_library(milvus_query ${MILVUS_QUERY_SRCS})
|
||||||
target_link_libraries(milvus_query libprotobuf)
|
target_link_libraries(milvus_query milvus_proto)
|
||||||
|
|||||||
@ -4,6 +4,8 @@
|
|||||||
#include <any>
|
#include <any>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
|
#include "segcore/SegmentDefs.h"
|
||||||
|
|
||||||
namespace milvus::query {
|
namespace milvus::query {
|
||||||
class ExprVisitor;
|
class ExprVisitor;
|
||||||
|
|
||||||
@ -58,7 +60,13 @@ using FieldId = std::string;
|
|||||||
|
|
||||||
struct TermExpr : Expr {
|
struct TermExpr : Expr {
|
||||||
FieldId field_id_;
|
FieldId field_id_;
|
||||||
std::vector<std::any> terms_; //
|
segcore::DataType data_type_;
|
||||||
|
// std::vector<std::any> terms_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// prevent accidential instantiation
|
||||||
|
TermExpr() = default;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void
|
void
|
||||||
accept(ExprVisitor&) override;
|
accept(ExprVisitor&) override;
|
||||||
@ -66,12 +74,14 @@ struct TermExpr : Expr {
|
|||||||
|
|
||||||
struct RangeExpr : Expr {
|
struct RangeExpr : Expr {
|
||||||
FieldId field_id_;
|
FieldId field_id_;
|
||||||
enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual };
|
segcore::DataType data_type_;
|
||||||
std::vector<std::tuple<OpType, std::any>> conditions_;
|
// std::vector<std::tuple<OpType, std::any>> conditions_;
|
||||||
|
protected:
|
||||||
|
// prevent accidential instantiation
|
||||||
|
RangeExpr() = default;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void
|
void
|
||||||
accept(ExprVisitor&) override;
|
accept(ExprVisitor&) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace milvus::query
|
} // namespace milvus::query
|
||||||
|
|||||||
16
internal/core/src/query/ExprImpl.h
Normal file
16
internal/core/src/query/ExprImpl.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#pragma once
|
||||||
|
#include "Expr.h"
|
||||||
|
|
||||||
|
namespace milvus::query {
|
||||||
|
template <typename T>
|
||||||
|
struct TermExprImpl : TermExpr {
|
||||||
|
std::vector<T> terms_;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct RangeExprImpl : RangeExpr {
|
||||||
|
enum class OpType { LessThan, LessEqual, GreaterThan, GreaterEqual, Equal, NotEqual };
|
||||||
|
std::vector<std::tuple<OpType, T>> conditions_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace milvus::query
|
||||||
@ -23,7 +23,7 @@ CreateVec(const std::string& field_name, const json& vec_info) {
|
|||||||
|
|
||||||
static std::unique_ptr<Plan>
|
static std::unique_ptr<Plan>
|
||||||
CreatePlanImplNaive(const std::string& dsl_str) {
|
CreatePlanImplNaive(const std::string& dsl_str) {
|
||||||
auto plan = std::unique_ptr<Plan>();
|
auto plan = std::make_unique<Plan>();
|
||||||
auto dsl = nlohmann::json::parse(dsl_str);
|
auto dsl = nlohmann::json::parse(dsl_str);
|
||||||
nlohmann::json vec_pack;
|
nlohmann::json vec_pack;
|
||||||
|
|
||||||
@ -36,17 +36,19 @@ CreatePlanImplNaive(const std::string& dsl_str) {
|
|||||||
auto key = iter.key();
|
auto key = iter.key();
|
||||||
auto& body = iter.value();
|
auto& body = iter.value();
|
||||||
plan->plan_node_ = CreateVec(key, body);
|
plan->plan_node_ = CreateVec(key, body);
|
||||||
|
return plan;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
PanicInfo("Unsupported DSL: vector node not detected");
|
||||||
} else if (bool_dsl.contains("vector")) {
|
} else if (bool_dsl.contains("vector")) {
|
||||||
auto iter = bool_dsl["vector"].begin();
|
auto iter = bool_dsl["vector"].begin();
|
||||||
auto key = iter.key();
|
auto key = iter.key();
|
||||||
auto& body = iter.value();
|
auto& body = iter.value();
|
||||||
plan->plan_node_ = CreateVec(key, body);
|
plan->plan_node_ = CreateVec(key, body);
|
||||||
|
return plan;
|
||||||
} else {
|
} else {
|
||||||
PanicInfo("Unsupported DSL: vector node not detected");
|
PanicInfo("Unsupported DSL: vector node not detected");
|
||||||
}
|
}
|
||||||
return plan;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
@ -55,6 +57,7 @@ CheckNull(const Json& json) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class PlanParser {
|
class PlanParser {
|
||||||
|
public:
|
||||||
void
|
void
|
||||||
ParseBoolBody(const Json& dsl) {
|
ParseBoolBody(const Json& dsl) {
|
||||||
CheckNull(dsl);
|
CheckNull(dsl);
|
||||||
@ -74,6 +77,8 @@ class PlanParser {
|
|||||||
}
|
}
|
||||||
PanicInfo("unimplemented");
|
PanicInfo("unimplemented");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<Plan>
|
std::unique_ptr<Plan>
|
||||||
@ -83,11 +88,12 @@ CreatePlan(const std::string& dsl_str) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<PlaceholderGroup>
|
std::unique_ptr<PlaceholderGroup>
|
||||||
ParsePlaceholderGroup(const char* placeholder_group_blob) {
|
ParsePlaceholderGroup(const std::string& blob) {
|
||||||
namespace ser = milvus::proto::service;
|
namespace ser = milvus::proto::service;
|
||||||
auto result = std::unique_ptr<PlaceholderGroup>();
|
auto result = std::make_unique<PlaceholderGroup>();
|
||||||
ser::PlaceholderGroup ph_group;
|
ser::PlaceholderGroup ph_group;
|
||||||
GOOGLE_PROTOBUF_PARSER_ASSERT(ph_group.ParseFromString(placeholder_group_blob));
|
auto ok = ph_group.ParseFromString(blob);
|
||||||
|
Assert(ok);
|
||||||
for (auto& info : ph_group.placeholders()) {
|
for (auto& info : ph_group.placeholders()) {
|
||||||
Placeholder element;
|
Placeholder element;
|
||||||
element.tag_ = info.tag();
|
element.tag_ = info.tag();
|
||||||
|
|||||||
@ -13,7 +13,7 @@ std::unique_ptr<Plan>
|
|||||||
CreatePlan(const std::string& dsl);
|
CreatePlan(const std::string& dsl);
|
||||||
|
|
||||||
std::unique_ptr<PlaceholderGroup>
|
std::unique_ptr<PlaceholderGroup>
|
||||||
ParsePlaceholderGroup(const char* placeholder_group_blob);
|
ParsePlaceholderGroup(const std::string& placeholder_group_blob);
|
||||||
|
|
||||||
int64_t
|
int64_t
|
||||||
GetNumOfQueries(const PlaceholderGroup*);
|
GetNumOfQueries(const PlaceholderGroup*);
|
||||||
@ -24,3 +24,5 @@ int64_t
|
|||||||
GetTopK(const Plan*);
|
GetTopK(const Plan*);
|
||||||
|
|
||||||
} // namespace milvus::query
|
} // namespace milvus::query
|
||||||
|
|
||||||
|
#include "PlanImpl.h"
|
||||||
@ -28,7 +28,6 @@ struct PlanNode {
|
|||||||
using PlanNodePtr = std::unique_ptr<PlanNode>;
|
using PlanNodePtr = std::unique_ptr<PlanNode>;
|
||||||
|
|
||||||
struct QueryInfo {
|
struct QueryInfo {
|
||||||
int64_t num_queries_;
|
|
||||||
int64_t topK_;
|
int64_t topK_;
|
||||||
FieldId field_id_;
|
FieldId field_id_;
|
||||||
std::string metric_type_; // TODO: use enum
|
std::string metric_type_; // TODO: use enum
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
// Generated File
|
// Generated File
|
||||||
// DO NOT EDIT
|
// DO NOT EDIT
|
||||||
|
#include "utils/Json.h"
|
||||||
|
#include "query/PlanImpl.h"
|
||||||
|
#include "segcore/SegmentBase.h"
|
||||||
#include "PlanNodeVisitor.h"
|
#include "PlanNodeVisitor.h"
|
||||||
|
|
||||||
namespace milvus::query {
|
namespace milvus::query {
|
||||||
class ExecPlanNodeVisitor : PlanNodeVisitor {
|
class ExecPlanNodeVisitor : PlanNodeVisitor {
|
||||||
public:
|
public:
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
// Generated File
|
// Generated File
|
||||||
// DO NOT EDIT
|
// DO NOT EDIT
|
||||||
|
#include "query/Plan.h"
|
||||||
|
#include "utils/EasyAssert.h"
|
||||||
|
#include "utils/Json.h"
|
||||||
#include "ExprVisitor.h"
|
#include "ExprVisitor.h"
|
||||||
|
|
||||||
namespace milvus::query {
|
namespace milvus::query {
|
||||||
class ShowExprVisitor : ExprVisitor {
|
class ShowExprVisitor : ExprVisitor {
|
||||||
public:
|
public:
|
||||||
@ -18,5 +22,35 @@ class ShowExprVisitor : ExprVisitor {
|
|||||||
visit(RangeExpr& expr) override;
|
visit(RangeExpr& expr) override;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
using RetType = Json;
|
||||||
|
|
||||||
|
public:
|
||||||
|
RetType
|
||||||
|
call_child(Expr& expr) {
|
||||||
|
assert(!ret_.has_value());
|
||||||
|
expr.accept(*this);
|
||||||
|
assert(ret_.has_value());
|
||||||
|
auto ret = std::move(ret_);
|
||||||
|
ret_ = std::nullopt;
|
||||||
|
return std::move(ret.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
Json
|
||||||
|
combine(Json&& extra, UnaryExpr& expr) {
|
||||||
|
auto result = std::move(extra);
|
||||||
|
result["child"] = call_child(*expr.child_);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
Json
|
||||||
|
combine(Json&& extra, BinaryExpr& expr) {
|
||||||
|
auto result = std::move(extra);
|
||||||
|
result["left_child"] = call_child(*expr.left_);
|
||||||
|
result["right_child"] = call_child(*expr.right_);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::optional<RetType> ret_;
|
||||||
};
|
};
|
||||||
} // namespace milvus::query
|
} // namespace milvus::query
|
||||||
|
|||||||
@ -1,7 +1,12 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
// Generated File
|
// Generated File
|
||||||
// DO NOT EDIT
|
// DO NOT EDIT
|
||||||
|
#include "utils/EasyAssert.h"
|
||||||
|
#include "utils/Json.h"
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "PlanNodeVisitor.h"
|
#include "PlanNodeVisitor.h"
|
||||||
|
|
||||||
namespace milvus::query {
|
namespace milvus::query {
|
||||||
class ShowPlanNodeVisitor : PlanNodeVisitor {
|
class ShowPlanNodeVisitor : PlanNodeVisitor {
|
||||||
public:
|
public:
|
||||||
@ -21,6 +26,7 @@ class ShowPlanNodeVisitor : PlanNodeVisitor {
|
|||||||
node.accept(*this);
|
node.accept(*this);
|
||||||
assert(ret_.has_value());
|
assert(ret_.has_value());
|
||||||
auto ret = std::move(ret_);
|
auto ret = std::move(ret_);
|
||||||
|
ret_ = std::nullopt;
|
||||||
return std::move(ret.value());
|
return std::move(ret.value());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -48,8 +48,10 @@ ExecPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
|||||||
auto segment = dynamic_cast<segcore::SegmentSmallIndex*>(&segment_);
|
auto segment = dynamic_cast<segcore::SegmentSmallIndex*>(&segment_);
|
||||||
AssertInfo(segment, "support SegmentSmallIndex Only");
|
AssertInfo(segment, "support SegmentSmallIndex Only");
|
||||||
RetType ret;
|
RetType ret;
|
||||||
auto src_data = placeholder_group_.at(0).get_blob<float>();
|
auto& ph = placeholder_group_.at(0);
|
||||||
segment->QueryBruteForceImpl(node.query_info_, src_data, timestamp_, ret);
|
auto src_data = ph.get_blob<float>();
|
||||||
|
auto num_queries = ph.num_of_queries_;
|
||||||
|
segment->QueryBruteForceImpl(node.query_info_, src_data, num_queries, timestamp_, ret);
|
||||||
ret_ = ret;
|
ret_ = ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
173
internal/core/src/query/visitors/ShowExprVisitor.cpp
Normal file
173
internal/core/src/query/visitors/ShowExprVisitor.cpp
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
#include "query/Plan.h"
|
||||||
|
#include "utils/EasyAssert.h"
|
||||||
|
#include "utils/Json.h"
|
||||||
|
#include "query/generated/ShowExprVisitor.h"
|
||||||
|
#include "query/ExprImpl.h"
|
||||||
|
|
||||||
|
namespace milvus::query {
|
||||||
|
using Json = nlohmann::json;
|
||||||
|
|
||||||
|
#if 1
|
||||||
|
// THIS CONTAINS EXTRA BODY FOR VISITOR
|
||||||
|
// WILL BE USED BY GENERATOR
|
||||||
|
namespace impl {
|
||||||
|
class ShowExprNodeVisitor : ExprVisitor {
|
||||||
|
public:
|
||||||
|
using RetType = Json;
|
||||||
|
|
||||||
|
public:
|
||||||
|
RetType
|
||||||
|
call_child(Expr& expr) {
|
||||||
|
assert(!ret_.has_value());
|
||||||
|
expr.accept(*this);
|
||||||
|
assert(ret_.has_value());
|
||||||
|
auto ret = std::move(ret_);
|
||||||
|
ret_ = std::nullopt;
|
||||||
|
return std::move(ret.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
Json
|
||||||
|
combine(Json&& extra, UnaryExpr& expr) {
|
||||||
|
auto result = std::move(extra);
|
||||||
|
result["child"] = call_child(*expr.child_);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
Json
|
||||||
|
combine(Json&& extra, BinaryExpr& expr) {
|
||||||
|
auto result = std::move(extra);
|
||||||
|
result["left_child"] = call_child(*expr.left_);
|
||||||
|
result["right_child"] = call_child(*expr.right_);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::optional<RetType> ret_;
|
||||||
|
};
|
||||||
|
} // namespace impl
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void
|
||||||
|
ShowExprVisitor::visit(BoolUnaryExpr& expr) {
|
||||||
|
Assert(!ret_.has_value());
|
||||||
|
using OpType = BoolUnaryExpr::OpType;
|
||||||
|
|
||||||
|
// TODO: use magic_enum if available
|
||||||
|
Assert(expr.op_type_ == OpType::LogicalNot);
|
||||||
|
auto op_name = "LogicalNot";
|
||||||
|
|
||||||
|
Json extra{
|
||||||
|
{"expr_type", "BoolUnary"},
|
||||||
|
{"op", op_name},
|
||||||
|
};
|
||||||
|
ret_ = this->combine(std::move(extra), expr);
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
ShowExprVisitor::visit(BoolBinaryExpr& expr) {
|
||||||
|
Assert(!ret_.has_value());
|
||||||
|
using OpType = BoolBinaryExpr::OpType;
|
||||||
|
|
||||||
|
// TODO: use magic_enum if available
|
||||||
|
auto op_name = [](OpType op) {
|
||||||
|
switch (op) {
|
||||||
|
case OpType::LogicalAnd:
|
||||||
|
return "LogicalAnd";
|
||||||
|
case OpType::LogicalOr:
|
||||||
|
return "LogicalOr";
|
||||||
|
case OpType::LogicalXor:
|
||||||
|
return "LogicalXor";
|
||||||
|
default:
|
||||||
|
PanicInfo("unsupported op");
|
||||||
|
}
|
||||||
|
}(expr.op_type_);
|
||||||
|
|
||||||
|
Json extra{
|
||||||
|
{"expr_type", "BoolBinary"},
|
||||||
|
{"op", op_name},
|
||||||
|
};
|
||||||
|
ret_ = this->combine(std::move(extra), expr);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static Json
|
||||||
|
TermExtract(const TermExpr& expr_raw) {
|
||||||
|
auto expr = dynamic_cast<const TermExprImpl<T>*>(&expr_raw);
|
||||||
|
Assert(expr);
|
||||||
|
return Json{expr->terms_};
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
ShowExprVisitor::visit(TermExpr& expr) {
|
||||||
|
Assert(!ret_.has_value());
|
||||||
|
Assert(segcore::field_is_vector(expr.data_type_) == false);
|
||||||
|
using segcore::DataType;
|
||||||
|
auto terms = [&] {
|
||||||
|
switch (expr.data_type_) {
|
||||||
|
case DataType::INT8:
|
||||||
|
return TermExtract<int8_t>(expr);
|
||||||
|
case DataType::INT16:
|
||||||
|
return TermExtract<int16_t>(expr);
|
||||||
|
case DataType::INT32:
|
||||||
|
return TermExtract<int32_t>(expr);
|
||||||
|
case DataType::INT64:
|
||||||
|
return TermExtract<int64_t>(expr);
|
||||||
|
case DataType::DOUBLE:
|
||||||
|
return TermExtract<double>(expr);
|
||||||
|
case DataType::FLOAT:
|
||||||
|
return TermExtract<float>(expr);
|
||||||
|
case DataType::BOOL:
|
||||||
|
return TermExtract<bool>(expr);
|
||||||
|
default:
|
||||||
|
PanicInfo("unsupported type");
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
|
||||||
|
Json res{{"expr_type", "Term"},
|
||||||
|
{"field_id", expr.field_id_},
|
||||||
|
{"data_type", segcore::datatype_name(expr.data_type_)},
|
||||||
|
{"terms", std::move(terms)}};
|
||||||
|
|
||||||
|
ret_ = res;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static Json
|
||||||
|
CondtionExtract(const RangeExpr& expr_raw) {
|
||||||
|
auto expr = dynamic_cast<const TermExprImpl<T>*>(&expr_raw);
|
||||||
|
Assert(expr);
|
||||||
|
return Json{expr->terms_};
|
||||||
|
}
|
||||||
|
|
||||||
|
void
|
||||||
|
ShowExprVisitor::visit(RangeExpr& expr) {
|
||||||
|
Assert(!ret_.has_value());
|
||||||
|
Assert(segcore::field_is_vector(expr.data_type_) == false);
|
||||||
|
using segcore::DataType;
|
||||||
|
auto conditions = [&] {
|
||||||
|
switch (expr.data_type_) {
|
||||||
|
case DataType::BOOL:
|
||||||
|
return CondtionExtract<bool>(expr);
|
||||||
|
case DataType::INT8:
|
||||||
|
return CondtionExtract<int8_t>(expr);
|
||||||
|
case DataType::INT16:
|
||||||
|
return CondtionExtract<int16_t>(expr);
|
||||||
|
case DataType::INT32:
|
||||||
|
return CondtionExtract<int32_t>(expr);
|
||||||
|
case DataType::INT64:
|
||||||
|
return CondtionExtract<int64_t>(expr);
|
||||||
|
case DataType::DOUBLE:
|
||||||
|
return CondtionExtract<double>(expr);
|
||||||
|
case DataType::FLOAT:
|
||||||
|
return CondtionExtract<float>(expr);
|
||||||
|
default:
|
||||||
|
PanicInfo("unsupported type");
|
||||||
|
}
|
||||||
|
}();
|
||||||
|
|
||||||
|
Json res{{"expr_type", "Range"},
|
||||||
|
{"field_id", expr.field_id_},
|
||||||
|
{"data_type", segcore::datatype_name(expr.data_type_)},
|
||||||
|
{"conditions", std::move(conditions)}};
|
||||||
|
}
|
||||||
|
} // namespace milvus::query
|
||||||
@ -19,6 +19,7 @@ class ShowPlanNodeVisitorImpl : PlanNodeVisitor {
|
|||||||
node.accept(*this);
|
node.accept(*this);
|
||||||
assert(ret_.has_value());
|
assert(ret_.has_value());
|
||||||
auto ret = std::move(ret_);
|
auto ret = std::move(ret_);
|
||||||
|
ret_ = std::nullopt;
|
||||||
return std::move(ret.value());
|
return std::move(ret.value());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,9 +43,7 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
|||||||
Json json_body{
|
Json json_body{
|
||||||
{"node_type", "FloatVectorANNS"}, //
|
{"node_type", "FloatVectorANNS"}, //
|
||||||
{"metric_type", info.metric_type_}, //
|
{"metric_type", info.metric_type_}, //
|
||||||
// {"dim", info.dim_}, //
|
|
||||||
{"field_id_", info.field_id_}, //
|
{"field_id_", info.field_id_}, //
|
||||||
{"num_queries", info.num_queries_}, //
|
|
||||||
{"topK", info.topK_}, //
|
{"topK", info.topK_}, //
|
||||||
{"search_params", info.search_params_}, //
|
{"search_params", info.search_params_}, //
|
||||||
{"placeholder_tag", node.placeholder_tag_}, //
|
{"placeholder_tag", node.placeholder_tag_}, //
|
||||||
@ -52,7 +51,7 @@ ShowPlanNodeVisitor::visit(FloatVectorANNS& node) {
|
|||||||
if (node.predicate_.has_value()) {
|
if (node.predicate_.has_value()) {
|
||||||
PanicInfo("unimplemented");
|
PanicInfo("unimplemented");
|
||||||
} else {
|
} else {
|
||||||
json_body["predicate"] = "nullopt";
|
json_body["predicate"] = "None";
|
||||||
}
|
}
|
||||||
ret_ = json_body;
|
ret_ = json_body;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -50,6 +50,36 @@ field_sizeof(DataType data_type, int dim = 1) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: use magic_enum when available
|
||||||
|
inline std::string
|
||||||
|
datatype_name(DataType data_type) {
|
||||||
|
switch (data_type) {
|
||||||
|
case DataType::BOOL:
|
||||||
|
return "bool";
|
||||||
|
case DataType::DOUBLE:
|
||||||
|
return "double";
|
||||||
|
case DataType::FLOAT:
|
||||||
|
return "float";
|
||||||
|
case DataType::INT8:
|
||||||
|
return "int8_t";
|
||||||
|
case DataType::INT16:
|
||||||
|
return "int16_t";
|
||||||
|
case DataType::INT32:
|
||||||
|
return "int32_t";
|
||||||
|
case DataType::INT64:
|
||||||
|
return "int64_t";
|
||||||
|
case DataType::VECTOR_FLOAT:
|
||||||
|
return "vector_float";
|
||||||
|
case DataType::VECTOR_BINARY: {
|
||||||
|
return "vector_binary";
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
auto err_msg = "Unsupported DataType(" + std::to_string((int)data_type) + ")";
|
||||||
|
PanicInfo(err_msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline bool
|
inline bool
|
||||||
field_is_vector(DataType datatype) {
|
field_is_vector(DataType datatype) {
|
||||||
return datatype == DataType::VECTOR_BINARY || datatype == DataType::VECTOR_FLOAT;
|
return datatype == DataType::VECTOR_BINARY || datatype == DataType::VECTOR_FLOAT;
|
||||||
|
|||||||
@ -223,6 +223,7 @@ get_barrier(const RecordType& record, Timestamp timestamp) {
|
|||||||
Status
|
Status
|
||||||
SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info,
|
SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info,
|
||||||
const float* query_data,
|
const float* query_data,
|
||||||
|
int64_t num_queries,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
QueryResult& results) {
|
QueryResult& results) {
|
||||||
// step 1: binary search to find the barrier of the snapshot
|
// step 1: binary search to find the barrier of the snapshot
|
||||||
@ -247,7 +248,6 @@ SegmentSmallIndex::QueryBruteForceImpl(const query::QueryInfo& info,
|
|||||||
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
Assert(field.get_data_type() == DataType::VECTOR_FLOAT);
|
||||||
auto dim = field.get_dim();
|
auto dim = field.get_dim();
|
||||||
auto topK = info.topK_;
|
auto topK = info.topK_;
|
||||||
auto num_queries = info.num_queries_;
|
|
||||||
auto total_count = topK * num_queries;
|
auto total_count = topK * num_queries;
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
|
|
||||||
@ -321,7 +321,6 @@ SegmentSmallIndex::QueryDeprecated(query::QueryDeprecatedPtr query_info, Timesta
|
|||||||
int64_t inferred_dim = query_info->query_raw_data.size() / query_info->num_queries;
|
int64_t inferred_dim = query_info->query_raw_data.size() / query_info->num_queries;
|
||||||
// TODO
|
// TODO
|
||||||
query::QueryInfo info{
|
query::QueryInfo info{
|
||||||
query_info->num_queries,
|
|
||||||
query_info->topK,
|
query_info->topK,
|
||||||
query_info->field_name,
|
query_info->field_name,
|
||||||
"L2",
|
"L2",
|
||||||
@ -329,7 +328,8 @@ SegmentSmallIndex::QueryDeprecated(query::QueryDeprecatedPtr query_info, Timesta
|
|||||||
{"nprobe", 10},
|
{"nprobe", 10},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
return QueryBruteForceImpl(info, query_info->query_raw_data.data(), timestamp, result);
|
auto num_queries = query_info->num_queries;
|
||||||
|
return QueryBruteForceImpl(info, query_info->query_raw_data.data(), num_queries, timestamp, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status
|
Status
|
||||||
@ -453,14 +453,15 @@ SegmentSmallIndex::GetMemoryUsageInBytes() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status
|
Status
|
||||||
SegmentSmallIndex::Search(const query::Plan* Plan,
|
SegmentSmallIndex::Search(const query::Plan* plan,
|
||||||
const query::PlaceholderGroup** placeholder_groups,
|
const query::PlaceholderGroup** placeholder_groups,
|
||||||
const Timestamp* timestamps,
|
const Timestamp* timestamps,
|
||||||
int num_groups,
|
int num_groups,
|
||||||
QueryResult& results) {
|
QueryResult& results) {
|
||||||
Assert(num_groups == 1);
|
Assert(num_groups == 1);
|
||||||
query::ExecPlanNodeVisitor visitor(*this, timestamps[0], *placeholder_groups[0]);
|
query::ExecPlanNodeVisitor visitor(*this, timestamps[0], *placeholder_groups[0]);
|
||||||
PanicInfo("unimplemented");
|
results = visitor.get_moved_result(*plan->plan_node_);
|
||||||
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace milvus::segcore
|
} // namespace milvus::segcore
|
||||||
|
|||||||
@ -143,6 +143,7 @@ class SegmentSmallIndex : public SegmentBase {
|
|||||||
Status
|
Status
|
||||||
QueryBruteForceImpl(const query::QueryInfo& info,
|
QueryBruteForceImpl(const query::QueryInfo& info,
|
||||||
const float* query_data,
|
const float* query_data,
|
||||||
|
int64_t num_queries,
|
||||||
Timestamp timestamp,
|
Timestamp timestamp,
|
||||||
QueryResult& results);
|
QueryResult& results);
|
||||||
|
|
||||||
|
|||||||
@ -9,6 +9,7 @@ set(MILVUS_TEST_FILES
|
|||||||
test_c_api.cpp
|
test_c_api.cpp
|
||||||
test_indexing.cpp
|
test_indexing.cpp
|
||||||
test_query.cpp
|
test_query.cpp
|
||||||
|
test_expr.cpp
|
||||||
)
|
)
|
||||||
add_executable(all_tests
|
add_executable(all_tests
|
||||||
${MILVUS_TEST_FILES}
|
${MILVUS_TEST_FILES}
|
||||||
|
|||||||
70
internal/core/unittest/test_expr.cpp
Normal file
70
internal/core/unittest/test_expr.cpp
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "query/Parser.h"
|
||||||
|
#include "query/Expr.h"
|
||||||
|
#include "query/PlanNode.h"
|
||||||
|
#include "query/generated/ExprVisitor.h"
|
||||||
|
#include "query/generated/PlanNodeVisitor.h"
|
||||||
|
#include "test_utils/DataGen.h"
|
||||||
|
#include "query/generated/ShowPlanNodeVisitor.h"
|
||||||
|
|
||||||
|
TEST(Expr, Naive) {
|
||||||
|
SUCCEED();
|
||||||
|
using namespace milvus::wtf;
|
||||||
|
std::string dsl_string = R"(
|
||||||
|
{
|
||||||
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
{
|
||||||
|
"term": {
|
||||||
|
"A": [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
5
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"range": {
|
||||||
|
"B": {
|
||||||
|
"GT": 1,
|
||||||
|
"LT": 100
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"vector": {
|
||||||
|
"Vec": {
|
||||||
|
"metric_type": "L2",
|
||||||
|
"params": {
|
||||||
|
"nprobe": 10
|
||||||
|
},
|
||||||
|
"query": "$0",
|
||||||
|
"topk": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})";
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Expr, ShowExecutor) {
|
||||||
|
using namespace milvus::query;
|
||||||
|
using namespace milvus::segcore;
|
||||||
|
auto node = std::make_unique<FloatVectorANNS>();
|
||||||
|
auto schema = std::make_shared<Schema>();
|
||||||
|
int64_t num_queries = 100L;
|
||||||
|
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16);
|
||||||
|
auto raw_data = DataGen(schema, num_queries);
|
||||||
|
auto& info = node->query_info_;
|
||||||
|
info.metric_type_ = "L2";
|
||||||
|
info.topK_ = 20;
|
||||||
|
info.field_id_ = "fakevec";
|
||||||
|
node->predicate_ = std::nullopt;
|
||||||
|
ShowPlanNodeVisitor show_visitor;
|
||||||
|
PlanNodePtr base(node.release());
|
||||||
|
auto res = show_visitor.call_child(*base);
|
||||||
|
auto dup = res;
|
||||||
|
dup["data"] = "...collased...";
|
||||||
|
std::cout << dup.dump(4);
|
||||||
|
}
|
||||||
@ -6,6 +6,8 @@
|
|||||||
#include "query/generated/PlanNodeVisitor.h"
|
#include "query/generated/PlanNodeVisitor.h"
|
||||||
#include "test_utils/DataGen.h"
|
#include "test_utils/DataGen.h"
|
||||||
#include "query/generated/ShowPlanNodeVisitor.h"
|
#include "query/generated/ShowPlanNodeVisitor.h"
|
||||||
|
#include "query/generated/ExecPlanNodeVisitor.h"
|
||||||
|
#include "query/PlanImpl.h"
|
||||||
|
|
||||||
TEST(Query, Naive) {
|
TEST(Query, Naive) {
|
||||||
SUCCEED();
|
SUCCEED();
|
||||||
@ -58,7 +60,6 @@ TEST(Query, ShowExecutor) {
|
|||||||
auto raw_data = DataGen(schema, num_queries);
|
auto raw_data = DataGen(schema, num_queries);
|
||||||
auto& info = node->query_info_;
|
auto& info = node->query_info_;
|
||||||
info.metric_type_ = "L2";
|
info.metric_type_ = "L2";
|
||||||
info.num_queries_ = 10;
|
|
||||||
info.topK_ = 20;
|
info.topK_ = 20;
|
||||||
info.field_id_ = "fakevec";
|
info.field_id_ = "fakevec";
|
||||||
node->predicate_ = std::nullopt;
|
node->predicate_ = std::nullopt;
|
||||||
@ -66,6 +67,87 @@ TEST(Query, ShowExecutor) {
|
|||||||
PlanNodePtr base(node.release());
|
PlanNodePtr base(node.release());
|
||||||
auto res = show_visitor.call_child(*base);
|
auto res = show_visitor.call_child(*base);
|
||||||
auto dup = res;
|
auto dup = res;
|
||||||
dup["data"] = "...collased...";
|
|
||||||
std::cout << dup.dump(4);
|
std::cout << dup.dump(4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Query, DSL) {
|
||||||
|
using namespace milvus::query;
|
||||||
|
using namespace milvus::segcore;
|
||||||
|
ShowPlanNodeVisitor shower;
|
||||||
|
|
||||||
|
std::string dsl_string = R"(
|
||||||
|
{
|
||||||
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
{
|
||||||
|
"vector": {
|
||||||
|
"Vec": {
|
||||||
|
"metric_type": "L2",
|
||||||
|
"params": {
|
||||||
|
"nprobe": 10
|
||||||
|
},
|
||||||
|
"query": "$0",
|
||||||
|
"topk": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})";
|
||||||
|
auto plan = CreatePlan(dsl_string);
|
||||||
|
auto res = shower.call_child(*plan->plan_node_);
|
||||||
|
std::cout << res.dump(4) << std::endl;
|
||||||
|
|
||||||
|
std::string dsl_string2 = R"(
|
||||||
|
{
|
||||||
|
"bool": {
|
||||||
|
"vector": {
|
||||||
|
"Vec": {
|
||||||
|
"metric_type": "L2",
|
||||||
|
"params": {
|
||||||
|
"nprobe": 10
|
||||||
|
},
|
||||||
|
"query": "$0",
|
||||||
|
"topk": 10
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})";
|
||||||
|
auto plan2 = CreatePlan(dsl_string2);
|
||||||
|
auto res2 = shower.call_child(*plan2->plan_node_);
|
||||||
|
std::cout << res2.dump(4) << std::endl;
|
||||||
|
ASSERT_EQ(res, res2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Query, ParsePlaceholderGroup) {
|
||||||
|
using namespace milvus::query;
|
||||||
|
using namespace milvus::segcore;
|
||||||
|
namespace ser = milvus::proto::service;
|
||||||
|
int num_queries = 10;
|
||||||
|
int dim = 16;
|
||||||
|
std::default_random_engine e;
|
||||||
|
std::normal_distribution<double> dis(0, 1);
|
||||||
|
ser::PlaceholderGroup raw_group;
|
||||||
|
auto value = raw_group.add_placeholders();
|
||||||
|
value->set_tag("$0");
|
||||||
|
value->set_type(ser::PlaceholderType::VECTOR_FLOAT);
|
||||||
|
for(int i = 0; i < num_queries; ++i) {
|
||||||
|
std::vector<float> vec;
|
||||||
|
for(int d = 0; d < dim; ++d) {
|
||||||
|
vec.push_back(dis(e));
|
||||||
|
}
|
||||||
|
// std::string line((char*)vec.data(), (char*)vec.data() + vec.size() * sizeof(float));
|
||||||
|
value->add_values(vec.data(), vec.size() * sizeof(float));
|
||||||
|
}
|
||||||
|
auto blob = raw_group.SerializeAsString();
|
||||||
|
//ser::PlaceholderGroup new_group;
|
||||||
|
//new_group.ParseFromString()
|
||||||
|
auto fuck = ParsePlaceholderGroup(blob);
|
||||||
|
int x = 1+1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST(Query, Exec) {
|
||||||
|
using namespace milvus::query;
|
||||||
|
using namespace milvus::segcore;
|
||||||
|
}
|
||||||
|
|||||||
@ -25,13 +25,20 @@ for UNITTEST_DIR in "${UNITTEST_DIRS[@]}"; do
|
|||||||
echo "The unittest folder does not exist!"
|
echo "The unittest folder does not exist!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
for test in `ls ${UNITTEST_DIR}`; do
|
|
||||||
echo $test " running..."
|
${UNITTEST_DIR}/all_tests
|
||||||
# run unittest
|
|
||||||
${UNITTEST_DIR}/${test}
|
|
||||||
if [ $? -ne 0 ]; then
|
if [ $? -ne 0 ]; then
|
||||||
echo ${UNITTEST_DIR}/${test} "run failed"
|
echo ${UNITTEST_DIR}/all_tests "run failed"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
done
|
|
||||||
|
#for test in `ls ${UNITTEST_DIR}`; do
|
||||||
|
# echo $test " running..."
|
||||||
|
# # run unittest
|
||||||
|
# ${UNITTEST_DIR}/${test}
|
||||||
|
# if [ $? -ne 0 ]; then
|
||||||
|
# echo ${UNITTEST_DIR}/${test} "run failed"
|
||||||
|
# exit 1
|
||||||
|
# fi
|
||||||
|
#done
|
||||||
done
|
done
|
||||||
|
|||||||
@ -14,7 +14,7 @@ def gen_file(rootfile, template, output, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def extract_extra_body(visitor_info, query_path):
|
def extract_extra_body(visitor_info, query_path):
|
||||||
pattern = re.compile("class(.*){\n((.|\n)*?)\n};", re.MULTILINE)
|
pattern = re.compile(r"class(.*){\n((.|\n)*?)\n};", re.MULTILINE)
|
||||||
|
|
||||||
for node, visitors in visitor_info.items():
|
for node, visitors in visitor_info.items():
|
||||||
for visitor in visitors:
|
for visitor in visitors:
|
||||||
@ -22,11 +22,24 @@ def extract_extra_body(visitor_info, query_path):
|
|||||||
vis_file = query_path + "visitors/" + vis_name + ".cpp"
|
vis_file = query_path + "visitors/" + vis_name + ".cpp"
|
||||||
body = ' public:'
|
body = ' public:'
|
||||||
|
|
||||||
|
inc_pattern_str = r'^(#include(.|\n)*)\n#include "query/generated/{}.h"'.format(vis_name)
|
||||||
|
inc_pattern = re.compile(inc_pattern_str, re.MULTILINE)
|
||||||
|
|
||||||
if os.path.exists(vis_file):
|
if os.path.exists(vis_file):
|
||||||
infos = pattern.findall(readfile(vis_file))
|
content = readfile(vis_file)
|
||||||
|
infos = pattern.findall(content)
|
||||||
|
assert len(infos) <= 1
|
||||||
if len(infos) == 1:
|
if len(infos) == 1:
|
||||||
name, body, _ = infos[0]
|
name, body, _ = infos[0]
|
||||||
|
|
||||||
|
extra_inc_infos = inc_pattern.findall(content)
|
||||||
|
assert(len(extra_inc_infos) <= 1)
|
||||||
|
print(extra_inc_infos)
|
||||||
|
if len(extra_inc_infos) == 1:
|
||||||
|
extra_inc_body, _ = extra_inc_infos[0]
|
||||||
|
|
||||||
visitor["ctor_and_member"] = body
|
visitor["ctor_and_member"] = body
|
||||||
|
visitor["extra_inc"] = extra_inc_body
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
query_path = "../../internal/core/src/query/"
|
query_path = "../../internal/core/src/query/"
|
||||||
|
|||||||
@ -9,7 +9,9 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
// Generated File
|
// Generated File
|
||||||
// DO NOT EDIT
|
// DO NOT EDIT
|
||||||
|
@@extra_inc@@
|
||||||
#include "@@base_visitor@@.h"
|
#include "@@base_visitor@@.h"
|
||||||
|
|
||||||
namespace @@namespace@@ {
|
namespace @@namespace@@ {
|
||||||
class @@visitor_name@@ : @@base_visitor@@ {
|
class @@visitor_name@@ : @@base_visitor@@ {
|
||||||
public:
|
public:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user