diff --git a/core/src/server/delivery/RequestHandler.cpp b/core/src/server/delivery/RequestHandler.cpp index 3b46fad5f6..a15f3dbaa4 100644 --- a/core/src/server/delivery/RequestHandler.cpp +++ b/core/src/server/delivery/RequestHandler.cpp @@ -40,6 +40,7 @@ #include "server/delivery/request/ShowPartitionsRequest.h" #include "server/delivery/hybrid_request/CreateHybridCollectionRequest.h" +#include "server/delivery/hybrid_request/DescribeHybridCollectionRequest.h" #include "server/delivery/hybrid_request/HybridSearchRequest.h" #include "server/delivery/hybrid_request/InsertEntityRequest.h" @@ -266,6 +267,15 @@ RequestHandler::CreateHybridCollection(const std::shared_ptr& context, return request_ptr->status(); } +Status +RequestHandler::DescribeHybridCollection(const std::shared_ptr& context, const std::string& collection_name, + std::unordered_map& field_types) { + BaseRequestPtr request_ptr = DescribeHybridCollectionRequest::Create(context, collection_name, field_types); + + RequestScheduler::ExecRequest(request_ptr); + return request_ptr->status(); +} + Status RequestHandler::HasHybridCollection(const std::shared_ptr& context, std::string& collection_name, bool& has_collection) { diff --git a/core/src/server/delivery/RequestHandler.h b/core/src/server/delivery/RequestHandler.h index c52a8e250b..32ca15078b 100644 --- a/core/src/server/delivery/RequestHandler.h +++ b/core/src/server/delivery/RequestHandler.h @@ -121,6 +121,10 @@ class RequestHandler { std::vector>& vector_dimensions, std::vector>& field_extra_params); + Status + DescribeHybridCollection(const std::shared_ptr& context, const std::string& collection_name, + std::unordered_map& field_types); + Status HasHybridCollection(const std::shared_ptr& context, std::string& collection_name, bool& has_collection); diff --git a/core/src/server/delivery/request/BaseRequest.cpp b/core/src/server/delivery/request/BaseRequest.cpp index 66e7f3c68c..8427c3165e 100644 --- a/core/src/server/delivery/request/BaseRequest.cpp +++ b/core/src/server/delivery/request/BaseRequest.cpp @@ -49,6 +49,7 @@ RequestGroup(BaseRequest::RequestType type) { {BaseRequest::kDropCollection, DDL_DML_REQUEST_GROUP}, {BaseRequest::kPreloadCollection, DQL_REQUEST_GROUP}, {BaseRequest::kCreateHybridCollection, DDL_DML_REQUEST_GROUP}, + {BaseRequest::kDescribeHybridCollection, INFO_REQUEST_GROUP}, // partition operations {BaseRequest::kCreatePartition, DDL_DML_REQUEST_GROUP}, diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index 5bd04a72b0..84bdc4c592 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -799,6 +799,13 @@ GrpcRequestHandler::CreateHybridCollection(::grpc::ServerContext* context, const return ::grpc::Status::OK; } +::grpc::Status +GrpcRequestHandler::DescribeHybridCollection(::grpc::ServerContext* context, + const ::milvus::grpc::CollectionName* request, + ::milvus::grpc::Mapping* response) { + CHECK_NULLPTR_RETURN(request); +} + ::grpc::Status GrpcRequestHandler::InsertEntity(::grpc::ServerContext* context, const ::milvus::grpc::HInsertParam* request, ::milvus::grpc::HEntityIDs* response) { @@ -916,7 +923,6 @@ GrpcRequestHandler::HybridSearch(::grpc::ServerContext* context, const ::milvus: DeSerialization(request->general_query(), boolean_query); query::GeneralQueryPtr general_query = std::make_shared(); - general_query->bin = std::make_shared(); query::GenBinaryQuery(boolean_query, general_query->bin); Status status; diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.h b/core/src/server/grpc_impl/GrpcRequestHandler.h index 4f62fe8857..cd7906dfa5 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.h +++ b/core/src/server/grpc_impl/GrpcRequestHandler.h @@ -320,10 +320,9 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service, // const ::milvus::grpc::CollectionName* request, // ::milvus::grpc::Status* response) override; // - // ::grpc::Status - // DescribeCollection(::grpc::ServerContext* context, - // const ::milvus::grpc::CollectionName* request, - // ::milvus::grpc::Mapping* response) override; + ::grpc::Status + DescribeHybridCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request, + ::milvus::grpc::Mapping* response) override; // // ::grpc::Status // CountCollection(::grpc::ServerContext* context, diff --git a/core/src/server/web_impl/controller/WebController.hpp b/core/src/server/web_impl/controller/WebController.hpp index ded3a3d97e..a4a2520dd9 100644 --- a/core/src/server/web_impl/controller/WebController.hpp +++ b/core/src/server/web_impl/controller/WebController.hpp @@ -11,13 +11,13 @@ #pragma once -#include #include +#include -#include -#include #include #include +#include +#include #include "utils/Log.h" #include "utils/TimeRecorder.h" @@ -37,11 +37,12 @@ namespace web { class WebController : public oatpp::web::server::api::ApiController { public: WebController(const std::shared_ptr& objectMapper) - : oatpp::web::server::api::ApiController(objectMapper) {} + : oatpp::web::server::api::ApiController(objectMapper) { + } public: - static std::shared_ptr createShared( - OATPP_COMPONENT(std::shared_ptr, objectMapper)) { + static std::shared_ptr + createShared(OATPP_COMPONENT(std::shared_ptr, objectMapper)) { return std::make_shared(objectMapper); } @@ -84,8 +85,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } @@ -115,8 +116,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } @@ -139,8 +140,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } @@ -172,8 +173,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; @@ -197,8 +198,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; } @@ -229,8 +230,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; } @@ -243,7 +244,6 @@ class WebController : public oatpp::web::server::api::ApiController { WebRequestHandler handler = WebRequestHandler(); - String result; auto status_dto = handler.ShowTables(query_params, result); std::shared_ptr response; @@ -255,8 +255,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; @@ -270,8 +270,8 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(GetTable) - ENDPOINT("GET", "/collections/{collection_name}", GetTable, - PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) { + ENDPOINT("GET", "/collections/{collection_name}", GetTable, PATH(String, collection_name), + QUERIES(const QueryParams&, query_params)) { TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'"); tr.RecordSection("Received request."); @@ -292,8 +292,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; @@ -320,8 +320,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; @@ -335,8 +335,8 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(CreateIndex) - ENDPOINT("POST", "/collections/{collection_name}/indexes", CreateIndex, - PATH(String, collection_name), BODY_STRING(String, body)) { + ENDPOINT("POST", "/collections/{collection_name}/indexes", CreateIndex, PATH(String, collection_name), + BODY_STRING(String, body)) { TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'"); tr.RecordSection("Received request."); @@ -355,8 +355,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; @@ -365,7 +365,8 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(GetIndex) ENDPOINT("GET", "/collections/{collection_name}/indexes", GetIndex, PATH(String, collection_name)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/indexes\'"); + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + + "/indexes\'"); tr.RecordSection("Received request."); auto handler = WebRequestHandler(); @@ -385,8 +386,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; @@ -395,7 +396,8 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(DropIndex) ENDPOINT("DELETE", "/collections/{collection_name}/indexes", DropIndex, PATH(String, collection_name)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "/indexes\'"); + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + + "/indexes\'"); tr.RecordSection("Received request."); auto handler = WebRequestHandler(); @@ -413,8 +415,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"; + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; tr.ElapseFromBegin(ttr); return response; @@ -428,9 +430,10 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(CreatePartition) - ENDPOINT("POST", "/collections/{collection_name}/partitions", - CreatePartition, PATH(String, collection_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/partitions\'"); + ENDPOINT("POST", "/collections/{collection_name}/partitions", CreatePartition, PATH(String, collection_name), + BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + + "/partitions\'"); tr.RecordSection("Received request."); auto handler = WebRequestHandler(); @@ -448,17 +451,18 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } ADD_CORS(ShowPartitions) - ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions, - PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/partitions\'"); + ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions, PATH(String, collection_name), + QUERIES(const QueryParams&, query_params)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + + "/partitions\'"); tr.RecordSection("Received request."); auto offset = query_params.get("offset"); @@ -476,21 +480,22 @@ class WebController : public oatpp::web::server::api::ApiController { case StatusCode::COLLECTION_NOT_EXISTS: response = createDtoResponse(Status::CODE_404, status_dto); break; - default:response = createDtoResponse(Status::CODE_400, status_dto); + default: + response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } ADD_CORS(DropPartition) - ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition, - PATH(String, collection_name), BODY_STRING(String, body)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + - "DELETE \'/collections/" + collection_name->std_str() + "/partitions\'"); + ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition, PATH(String, collection_name), + BODY_STRING(String, body)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + + "/partitions\'"); tr.RecordSection("Received request."); auto handler = WebRequestHandler(); @@ -508,16 +513,16 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } ADD_CORS(ShowSegments) - ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments, - PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) { + ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments, PATH(String, collection_name), + QUERIES(const QueryParams&, query_params)) { auto offset = query_params.get("offset"); auto page_size = query_params.get("page_size"); @@ -541,7 +546,8 @@ class WebController : public oatpp::web::server::api::ApiController { * GetSegmentVector */ ENDPOINT("GET", "/collections/{collection_name}/segments/{segment_name}/{info}", GetSegmentInfo, - PATH(String, collection_name), PATH(String, segment_name), PATH(String, info), QUERIES(const QueryParams&, query_params)) { + PATH(String, collection_name), PATH(String, segment_name), PATH(String, info), + QUERIES(const QueryParams&, query_params)) { auto offset = query_params.get("offset"); auto page_size = query_params.get("page_size"); @@ -570,8 +576,8 @@ class WebController : public oatpp::web::server::api::ApiController { * * GetVectorByID ?id= */ - ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, - PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) { + ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, PATH(String, collection_name), + QUERIES(const QueryParams&, query_params)) { auto handler = WebRequestHandler(); String response; auto status_dto = handler.GetVector(collection_name, query_params, response); @@ -588,9 +594,10 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(Insert) - ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert, - PATH(String, collection_name), BODY_STRING(String, body)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/vectors\'"); + ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert, PATH(String, collection_name), + BODY_STRING(String, body)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + + "/vectors\'"); tr.RecordSection("Received request."); auto ids_dto = VectorIdsDto::createShared(); @@ -609,17 +616,48 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); + + return response; + } + + ADD_CORS(InsertEntity) + + ENDPOINT("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name), + BODY_STRING(String, body)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections/" + collection_name->std_str() + + "/entities\'"); + tr.RecordSection("Received request."); + + auto ids_dto = VectorIdsDto::createShared(); + WebRequestHandler handler = WebRequestHandler(); + + std::shared_ptr response; + auto status_dto = handler.InsertEntity(collection_name, body, ids_dto); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, ids_dto); + break; + case StatusCode::COLLECTION_NOT_EXISTS: + response = createDtoResponse(Status::CODE_404, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } ADD_CORS(VectorsOp) - ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp, - PATH(String, collection_name), BODY_STRING(String, body)) { - TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() + "/vectors\'"); + ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp, PATH(String, collection_name), + BODY_STRING(String, body)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() + + "/vectors\'"); tr.RecordSection("Received request."); WebRequestHandler handler = WebRequestHandler(); @@ -638,8 +676,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } @@ -668,8 +706,8 @@ class WebController : public oatpp::web::server::api::ApiController { response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } @@ -694,19 +732,41 @@ class WebController : public oatpp::web::server::api::ApiController { default: response = createDtoResponse(Status::CODE_400, status_dto); } - tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) - + ", reason = " + status_dto->message->std_str() + ". Total cost"); + tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"); return response; } + ADD_CORS(CreateHybridCollection) + + ENDPOINT("POST", "/hybrid_collections", CreateHybridCollection, BODY_STRING(String, body_str)) { + TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections\'"); + tr.RecordSection("Received request."); + WebRequestHandler handler = WebRequestHandler(); + + std::shared_ptr response; + auto status_dto = handler.CreateHybridCollection(body_str); + switch (status_dto->code->getValue()) { + case StatusCode::SUCCESS: + response = createDtoResponse(Status::CODE_201, status_dto); + break; + default: + response = createDtoResponse(Status::CODE_400, status_dto); + } + + std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) + + ", reason = " + status_dto->message->std_str() + ". Total cost"; + tr.ElapseFromBegin(ttr); + return response; + } + /** * Finish ENDPOINTs generation ('ApiController' codegen) */ #include OATPP_CODEGEN_END(ApiController) - }; -} // namespace web -} // namespace server -} // namespace milvus +} // namespace web +} // namespace server +} // namespace milvus diff --git a/core/src/server/web_impl/handler/WebRequestHandler.cpp b/core/src/server/web_impl/handler/WebRequestHandler.cpp index 27c81b557e..6748cd66f7 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.cpp +++ b/core/src/server/web_impl/handler/WebRequestHandler.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include "config/Config.h" @@ -567,6 +568,251 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js return Status::OK(); } +Status +WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query) { + if (json.contains("term")) { + auto leaf_query = std::make_shared(); + auto term_json = json["term"]; + std::string field_name = term_json["field_name"]; + auto term_value_json = term_json["values"]; + if (!term_value_json.is_array()) { + std::string msg = "Term json string is not an array"; + return Status{BODY_PARSE_FAIL, msg}; + } + + auto term_size = term_value_json.size(); + auto term_query = std::make_shared(); + term_query->field_name = field_name; + term_query->field_value.resize(term_size * sizeof(int64_t)); + + switch (field_type_.at(field_name)) { + case engine::meta::hybrid::DataType::INT8: + case engine::meta::hybrid::DataType::INT16: + case engine::meta::hybrid::DataType::INT32: + case engine::meta::hybrid::DataType::INT64: { + std::vector term_value(term_size, 0); + for (uint64_t i = 0; i < term_size; ++i) { + term_value[i] = term_value_json[i].get(); + } + memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(int64_t)); + break; + } + case engine::meta::hybrid::DataType::FLOAT: + case engine::meta::hybrid::DataType::DOUBLE: { + std::vector term_value(term_size, 0); + for (uint64_t i = 0; i < term_size; ++i) { + term_value[i] = term_value_json[i].get(); + } + memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(double)); + break; + } + } + + leaf_query->term_query = term_query; + query->AddLeafQuery(leaf_query); + } else if (json.contains("range")) { + auto leaf_query = std::make_shared(); + auto range_query = std::make_shared(); + + auto range_json = json["range"]; + std::string field_name = range_json["field_name"]; + range_query->field_name = field_name; + + auto range_value_json = range_json["values"]; + if (range_value_json.contains("lt")) { + query::CompareExpr compare_expr; + compare_expr.compare_operator = query::CompareOperator::LT; + compare_expr.operand = range_value_json["lt"].get(); + range_query->compare_expr.emplace_back(compare_expr); + } + if (range_value_json.contains("lte")) { + query::CompareExpr compare_expr; + compare_expr.compare_operator = query::CompareOperator::LTE; + compare_expr.operand = range_value_json["lte"].get(); + range_query->compare_expr.emplace_back(compare_expr); + } + if (range_value_json.contains("eq")) { + query::CompareExpr compare_expr; + compare_expr.compare_operator = query::CompareOperator::EQ; + compare_expr.operand = range_value_json["eq"].get(); + range_query->compare_expr.emplace_back(compare_expr); + } + if (range_value_json.contains("ne")) { + query::CompareExpr compare_expr; + compare_expr.compare_operator = query::CompareOperator::NE; + compare_expr.operand = range_value_json["ne"].get(); + range_query->compare_expr.emplace_back(compare_expr); + } + if (range_value_json.contains("gt")) { + query::CompareExpr compare_expr; + compare_expr.compare_operator = query::CompareOperator::GT; + compare_expr.operand = range_value_json["gt"].get(); + range_query->compare_expr.emplace_back(compare_expr); + } + if (range_value_json.contains("gte")) { + query::CompareExpr compare_expr; + compare_expr.compare_operator = query::CompareOperator::GTE; + compare_expr.operand = range_value_json["gte"].get(); + range_query->compare_expr.emplace_back(compare_expr); + } + + leaf_query->range_query = range_query; + query->AddLeafQuery(leaf_query); + } else if (json.contains("vector")) { + auto leaf_query = std::make_shared(); + auto vector_query = std::make_shared(); + + auto vector_json = json["vector"]; + std::string field_name = vector_json["field_name"]; + vector_query->field_name = field_name; + + engine::VectorsData vectors; + // TODO(yukun): process binary vector + CopyRecordsFromJson(vector_json["values"], vectors, false); + + vector_query->query_vector.float_data = vectors.float_data_; + vector_query->query_vector.binary_data = vectors.binary_data_; + + vector_query->topk = vector_json["topk"].get(); + vector_query->extra_params = vector_json["extra_params"]; + + leaf_query->vector_query = vector_query; + query->AddLeafQuery(leaf_query); + } + return Status::OK(); +} + +Status +WebRequestHandler::ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query) { + if (query_json.contains("must")) { + boolean_query->SetOccur(query::Occur::MUST); + auto must_json = query_json["must"]; + if (!must_json.is_array()) { + std::string msg = "Must json string is not an array"; + return Status{BODY_PARSE_FAIL, msg}; + } + + for (auto& json : must_json) { + auto must_query = std::make_shared(); + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + ProcessBoolQueryJson(json, must_query); + boolean_query->AddBooleanQuery(must_query); + } else { + ProcessLeafQueryJson(json, boolean_query); + } + } + return Status::OK(); + } else if (query_json.contains("should")) { + boolean_query->SetOccur(query::Occur::SHOULD); + auto should_json = query_json["should"]; + if (!should_json.is_array()) { + std::string msg = "Should json string is not an array"; + return Status{BODY_PARSE_FAIL, msg}; + } + + for (auto& json : should_json) { + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + auto should_query = std::make_shared(); + ProcessBoolQueryJson(json, should_query); + boolean_query->AddBooleanQuery(should_query); + } else { + ProcessLeafQueryJson(json, boolean_query); + } + } + return Status::OK(); + } else if (query_json.contains("must_not")) { + boolean_query->SetOccur(query::Occur::MUST_NOT); + auto should_json = query_json["must_not"]; + if (!should_json.is_array()) { + std::string msg = "Must_not json string is not an array"; + return Status{BODY_PARSE_FAIL, msg}; + } + + for (auto& json : should_json) { + if (json.contains("must") || json.contains("should") || json.contains("must_not")) { + auto must_not_query = std::make_shared(); + ProcessBoolQueryJson(json, must_not_query); + boolean_query->AddBooleanQuery(must_not_query); + } else { + ProcessLeafQueryJson(json, boolean_query); + } + } + return Status::OK(); + } else { + std::string msg = "Must json string doesnot include right query"; + return Status{BODY_PARSE_FAIL, msg}; + } +} + +Status +WebRequestHandler::HybridSearch(const std::string& collection_name, const nlohmann::json& json, + std::string& result_str) { + Status status; + + status = request_handler_.DescribeHybridCollection(context_ptr_, collection_name, field_type_); + if (!status.ok()) { + return Status{UNEXPECTED_ERROR, "DescribeHybridCollection failed"}; + } + + std::vector partition_tags; + if (json.contains("partition_tags")) { + auto tags = json["partition_tags"]; + if (!tags.is_null() && !tags.is_array()) { + return Status(BODY_PARSE_FAIL, "Field \"partition_tags\" must be a array"); + } + + for (auto& tag : tags) { + partition_tags.emplace_back(tag.get()); + } + } + + if (json.contains("bool")) { + auto boolean_query_json = json["bool"]; + query::BooleanQueryPtr boolean_query = std::make_shared(); + + status = ProcessBoolQueryJson(boolean_query_json, boolean_query); + if (!status.ok()) { + return status; + } + query::GeneralQueryPtr general_query = std::make_shared(); + query::GenBinaryQuery(boolean_query, general_query->bin); + + context::HybridSearchContextPtr hybrid_search_context = std::make_shared(); + TopKQueryResult result; + status = request_handler_.HybridSearch(context_ptr_, hybrid_search_context, collection_name, partition_tags, + general_query, result); + + if (!status.ok()) { + return status; + } + + nlohmann::json result_json; + result_json["num"] = result.row_num_; + if (result.row_num_ == 0) { + result_json["result"] = std::vector(); + result_str = result_json.dump(); + return Status::OK(); + } + + auto step = result.id_list_.size() / result.row_num_; + nlohmann::json search_result_json; + for (size_t i = 0; i < result.row_num_; i++) { + nlohmann::json raw_result_json; + for (size_t j = 0; j < step; j++) { + nlohmann::json one_result_json; + one_result_json["id"] = std::to_string(result.id_list_.at(i * step + j)); + one_result_json["distance"] = std::to_string(result.distance_list_.at(i * step + j)); + raw_result_json.emplace_back(one_result_json); + } + search_result_json.emplace_back(raw_result_json); + } + result_json["result"] = search_result_json; + result_str = result_json.dump(); + } + + return Status::OK(); +} + Status WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str) { @@ -930,6 +1176,50 @@ WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& collection_ ASSIGN_RETURN_STATUS_DTO(status) } +StatusDto::ObjectWrapper +WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& body) { + auto json_str = nlohmann::json::parse(body->c_str()); + std::string collection_name = json_str["collection_name"]; + + // TODO(yukun): do checking + std::vector> field_types; + std::vector> field_extra_params; + std::vector> vector_dimensions; + for (auto& field : json_str["fields"]) { + std::string field_name = field["field_name"]; + std::string field_type = field["field_type"]; + auto extra_params = field["extra_params"]; + if (field_type == "int8") { + field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT8)); + } else if (field_type == "int16") { + field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT16)); + } else if (field_type == "int32") { + field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT32)); + } else if (field_type == "int64") { + field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT64)); + } else if (field_type == "float") { + field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::FLOAT)); + } else if (field_type == "double") { + field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::DOUBLE)); + } else if (field_type == "vector") { + } else { + std::string msg = field_name + " has wrong field_type"; + RETURN_STATUS_DTO(BODY_PARSE_FAIL, msg.c_str()); + } + + field_extra_params.emplace_back(std::make_pair(field_name, extra_params.dump())); + + if (extra_params.contains("dimension")) { + vector_dimensions.emplace_back(std::make_pair(field_name, extra_params["dimension"].get())); + } + } + + auto status = request_handler_.CreateHybridCollection(context_ptr_, collection_name, field_types, vector_dimensions, + field_extra_params); + + ASSIGN_RETURN_STATUS_DTO(status) +} + StatusDto::ObjectWrapper WebRequestHandler::ShowTables(const OQueryParams& query_params, OString& result) { int64_t offset = 0; @@ -1347,6 +1637,106 @@ WebRequestHandler::Insert(const OString& collection_name, const OString& body, V ASSIGN_RETURN_STATUS_DTO(status) } +StatusDto::ObjectWrapper +WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::server::web::OString& body, + VectorIdsDto::ObjectWrapper& ids_dto) { + if (nullptr == body.get() || body->getSize() == 0) { + RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.") + } + + auto body_json = nlohmann::json::parse(body->c_str()); + std::string partition_tag = body_json["partition_tag"]; + + uint64_t row_num = body_json["row_num"]; + + std::unordered_map field_types; + auto status = request_handler_.DescribeHybridCollection(context_ptr_, collection_name->c_str(), field_types); + + auto entities = body_json["entity"]; + if (!entities.is_array()) { + RETURN_STATUS_DTO(ILLEGAL_BODY, "An entity must be an array"); + } + + std::vector field_names; + std::vector> attr_values; + size_t attr_size = 0; + std::unordered_map vector_datas; + for (auto& entity : entities) { + std::string field_name = entity["field_name"]; + field_names.emplace_back(field_name); + auto field_value = entity["field_value"]; + std::vector attr_value; + switch (field_types.at(field_name)) { + case engine::meta::hybrid::DataType::INT8: + case engine::meta::hybrid::DataType::INT16: + case engine::meta::hybrid::DataType::INT32: + case engine::meta::hybrid::DataType::INT64: { + std::vector value; + auto size = field_value.size(); + value.resize(size); + attr_value.resize(size * sizeof(int64_t)); + size_t offset = 0; + for (auto data : field_value) { + value[offset] = data.get(); + ++offset; + } + memcpy(attr_value.data(), value.data(), size * sizeof(int64_t)); + attr_size += size * sizeof(int64_t); + attr_values.emplace_back(attr_value); + break; + } + case engine::meta::hybrid::DataType::FLOAT: + case engine::meta::hybrid::DataType::DOUBLE: { + std::vector value; + auto size = field_value.size(); + value.resize(size); + attr_value.resize(size * sizeof(double)); + size_t offset = 0; + for (auto data : field_value) { + value[offset] = data.get(); + ++offset; + } + memcpy(attr_value.data(), value.data(), size * sizeof(double)); + attr_size += size * sizeof(double); + + attr_values.emplace_back(attr_value); + break; + } + case engine::meta::hybrid::DataType::VECTOR: { + bool bin_flag; + status = IsBinaryTable(collection_name->c_str(), bin_flag); + if (!status.ok()) { + ASSIGN_RETURN_STATUS_DTO(status) + } + + engine::VectorsData vectors; + CopyRecordsFromJson(field_value, vectors, bin_flag); + vector_datas.insert(std::make_pair(field_name, vectors)); + } + default: {} + } + } + + std::vector attrs(attr_size, 0); + size_t attr_offset = 0; + for (auto& data : attr_values) { + memcpy(attrs.data() + attr_offset, data.data(), data.size()); + attr_offset += data.size(); + } + + status = request_handler_.InsertEntity(context_ptr_, collection_name->c_str(), partition_tag, row_num, field_names, + attrs, vector_datas); + + if (status.ok()) { + ids_dto->ids = ids_dto->ids->createShared(); + for (auto& id : vector_datas.begin()->second.id_array_) { + ids_dto->ids->pushBack(std::to_string(id).c_str()); + } + } + + ASSIGN_RETURN_STATUS_DTO(status) +} + StatusDto::ObjectWrapper WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) { int64_t id = 0; @@ -1389,6 +1779,8 @@ WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payl status = DeleteByIDs(collection_name->std_str(), payload_json["delete"], result_str); } else if (payload_json.contains("search")) { status = Search(collection_name->std_str(), payload_json["search"], result_str); + } else if (payload_json.contains("query")) { + status = HybridSearch(collection_name->c_str(), payload_json["query"], result_str); } else { status = Status(ILLEGAL_BODY, "Unknown body"); } diff --git a/core/src/server/web_impl/handler/WebRequestHandler.h b/core/src/server/web_impl/handler/WebRequestHandler.h index aeef4b9e12..c2f7a8f9bd 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.h +++ b/core/src/server/web_impl/handler/WebRequestHandler.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -141,6 +142,15 @@ class WebRequestHandler { Status Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str); + Status + ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query); + + Status + ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query); + + Status + HybridSearch(const std::string& collection_name, const nlohmann::json& json, std::string& result_str); + Status DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str); @@ -176,6 +186,9 @@ class WebRequestHandler { StatusDto::ObjectWrapper ShowTables(const OQueryParams& query_params, OString& result); + StatusDto::ObjectWrapper + CreateHybridCollection(const OString& body); + StatusDto::ObjectWrapper GetTable(const OString& collection_name, const OQueryParams& query_params, OString& result); @@ -219,6 +232,9 @@ class WebRequestHandler { StatusDto::ObjectWrapper Insert(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto); + StatusDto::ObjectWrapper + InsertEntity(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto); + StatusDto::ObjectWrapper GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response); @@ -244,6 +260,7 @@ class WebRequestHandler { private: std::shared_ptr context_ptr_; RequestHandler request_handler_; + std::unordered_map field_type_; }; } // namespace web diff --git a/core/unittest/server/test_web.cpp b/core/unittest/server/test_web.cpp index 6906b2f591..e313fb8c6b 100644 --- a/core/unittest/server/test_web.cpp +++ b/core/unittest/server/test_web.cpp @@ -156,6 +156,17 @@ RandomBinRecordsJson(int64_t dim, int64_t num) { return json; } +nlohmann::json +RandomAttrRecordsJson(int64_t row_num) { + nlohmann::json json; + std::default_random_engine e; + std::uniform_int_distribution u(0, 1000); + for (size_t i = 0; i < row_num; i++) { + json.push_back(u(e)); + } + return json; +} + std::string RandomName() { unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); @@ -697,6 +708,12 @@ class TestClient : public oatpp::web::client::ApiClient { API_CALL("PUT", "/system/{op}", op, PATH(String, cmd_str, "op"), BODY_STRING(String, body)) + API_CALL("POST", "/hybrid_collections", createHybridCollection, BODY_STRING(String, body_str)) + + API_CALL("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name), BODY_STRING(String, body)) + +// API_CALL("POST", "/hybrid_collections/{collection_name}/vectors", HybridSearch, PATH(String, collection_name), BODY_STRING(String, body)) + #include OATPP_CODEGEN_END(ApiClient) }; @@ -967,6 +984,92 @@ TEST_F(WebControllerTest, CREATE_COLLECTION) { ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); } +TEST_F(WebControllerTest, HYBRID_TEST) { + nlohmann::json create_json; + create_json["collection_name"] = "test_hybrid"; + nlohmann::json field_json_0, field_json_1; + field_json_0["field_name"] = "field_0"; + field_json_0["field_type"] = "int64"; + field_json_0["extra_params"] = ""; + + field_json_1["field_name"] = "field_1"; + field_json_1["field_type"] = "vector"; + nlohmann::json extra_params; + extra_params["dimension"] = 128; + field_json_1["extra_params"] = extra_params; + + create_json["fields"].push_back(field_json_0); + create_json["fields"].push_back(field_json_1); + + auto response = client_ptr->createHybridCollection(create_json.dump().c_str()); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, result_dto->code->getValue()) << result_dto->message->std_str(); + + int64_t dimension = 128; + int64_t row_num = 1000; + nlohmann::json insert_json; + insert_json["partition_tag"] = ""; + nlohmann::json entity_0, entity_1; + entity_0["field_name"] = "field_0"; + entity_0["field_value"] = RandomAttrRecordsJson(row_num); + entity_1["field_name"] = "field_1"; + entity_1["field_value"] = RandomRecordsJson(dimension, row_num); + + insert_json["entity"].push_back(entity_0); + insert_json["entity"].push_back(entity_1); + insert_json["row_num"] = row_num; + + OString collection_name = "test_hybrid"; + response = client_ptr->InsertEntity(collection_name, insert_json.dump().c_str(), conncetion_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto vector_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(row_num, vector_dto->ids->count()); + + auto status = FlushTable(client_ptr, conncetion_ptr, collection_name); + ASSERT_TRUE(status.ok()) << status.message(); + + // TODO(yukun): when hybrid operation is added to wal, the sleep() can be deleted + sleep(2); + + int64_t nq = 10; + int64_t topk = 100; + nlohmann::json query_json, bool_json, term_json, range_json, vector_json; + term_json["term"]["field_name"] = "field_0"; + term_json["term"]["values"] = RandomAttrRecordsJson(nq); + bool_json["must"].push_back(term_json); + + range_json["range"]["field_name"] = "field_0"; + nlohmann::json comp_json; + comp_json["gte"] = "0"; + comp_json["lte"] = "100000"; + range_json["range"]["values"] = comp_json; + bool_json["must"].push_back(range_json); + + vector_json["vector"]["field_name"] = "field_1"; + vector_json["vector"]["topk"] = topk; + vector_json["vector"]["nq"] = nq; + vector_json["vector"]["values"] = RandomRecordsJson(128, nq); + bool_json["must"].push_back(vector_json); + + query_json["query"]["bool"] = bool_json; + + response = client_ptr->vectorsOp(collection_name, query_json.dump().c_str(), conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + auto result_json = nlohmann::json::parse(response->readBodyToString()->std_str()); + ASSERT_TRUE(result_json.contains("num")); + ASSERT_TRUE(result_json["num"].is_number()); + ASSERT_EQ(nq, result_json["num"].get()); + + ASSERT_TRUE(result_json.contains("result")); + ASSERT_TRUE(result_json["result"].is_array()); + + auto result0_json = result_json["result"][0]; + ASSERT_TRUE(result0_json.is_array()); + ASSERT_EQ(topk, result0_json.size()); +} + TEST_F(WebControllerTest, GET_COLLECTION_META) { OString collection_name = "web_test_create_collection" + OString(RandomName().c_str()); GenTable(client_ptr, conncetion_ptr, collection_name, 10, 10, "L2");