mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-31 07:55:38 +08:00
Add http interface for hybrid search (#2079)
* Add http interface for hybrid search Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Add unittest for http hybrid interface Signed-off-by: fishpenguin <kun.yu@zilliz.com> * clang format Signed-off-by: fishpenguin <kun.yu@zilliz.com> * Fix codacy quality Signed-off-by: fishpenguin <kun.yu@zilliz.com>
This commit is contained in:
parent
6d87dc3c72
commit
67d8a9b54c
@ -40,6 +40,7 @@
|
||||
#include "server/delivery/request/ShowPartitionsRequest.h"
|
||||
|
||||
#include "server/delivery/hybrid_request/CreateHybridCollectionRequest.h"
|
||||
#include "server/delivery/hybrid_request/DescribeHybridCollectionRequest.h"
|
||||
#include "server/delivery/hybrid_request/HybridSearchRequest.h"
|
||||
#include "server/delivery/hybrid_request/InsertEntityRequest.h"
|
||||
|
||||
@ -266,6 +267,15 @@ RequestHandler::CreateHybridCollection(const std::shared_ptr<Context>& context,
|
||||
return request_ptr->status();
|
||||
}
|
||||
|
||||
Status
|
||||
RequestHandler::DescribeHybridCollection(const std::shared_ptr<Context>& context, const std::string& collection_name,
|
||||
std::unordered_map<std::string, engine::meta::hybrid::DataType>& field_types) {
|
||||
BaseRequestPtr request_ptr = DescribeHybridCollectionRequest::Create(context, collection_name, field_types);
|
||||
|
||||
RequestScheduler::ExecRequest(request_ptr);
|
||||
return request_ptr->status();
|
||||
}
|
||||
|
||||
Status
|
||||
RequestHandler::HasHybridCollection(const std::shared_ptr<Context>& context, std::string& collection_name,
|
||||
bool& has_collection) {
|
||||
|
||||
@ -121,6 +121,10 @@ class RequestHandler {
|
||||
std::vector<std::pair<std::string, uint64_t>>& vector_dimensions,
|
||||
std::vector<std::pair<std::string, std::string>>& field_extra_params);
|
||||
|
||||
Status
|
||||
DescribeHybridCollection(const std::shared_ptr<Context>& context, const std::string& collection_name,
|
||||
std::unordered_map<std::string, engine::meta::hybrid::DataType>& field_types);
|
||||
|
||||
Status
|
||||
HasHybridCollection(const std::shared_ptr<Context>& context, std::string& collection_name, bool& has_collection);
|
||||
|
||||
|
||||
@ -49,6 +49,7 @@ RequestGroup(BaseRequest::RequestType type) {
|
||||
{BaseRequest::kDropCollection, DDL_DML_REQUEST_GROUP},
|
||||
{BaseRequest::kPreloadCollection, DQL_REQUEST_GROUP},
|
||||
{BaseRequest::kCreateHybridCollection, DDL_DML_REQUEST_GROUP},
|
||||
{BaseRequest::kDescribeHybridCollection, INFO_REQUEST_GROUP},
|
||||
|
||||
// partition operations
|
||||
{BaseRequest::kCreatePartition, DDL_DML_REQUEST_GROUP},
|
||||
|
||||
@ -799,6 +799,13 @@ GrpcRequestHandler::CreateHybridCollection(::grpc::ServerContext* context, const
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
|
||||
::grpc::Status
|
||||
GrpcRequestHandler::DescribeHybridCollection(::grpc::ServerContext* context,
|
||||
const ::milvus::grpc::CollectionName* request,
|
||||
::milvus::grpc::Mapping* response) {
|
||||
CHECK_NULLPTR_RETURN(request);
|
||||
}
|
||||
|
||||
::grpc::Status
|
||||
GrpcRequestHandler::InsertEntity(::grpc::ServerContext* context, const ::milvus::grpc::HInsertParam* request,
|
||||
::milvus::grpc::HEntityIDs* response) {
|
||||
@ -916,7 +923,6 @@ GrpcRequestHandler::HybridSearch(::grpc::ServerContext* context, const ::milvus:
|
||||
DeSerialization(request->general_query(), boolean_query);
|
||||
|
||||
query::GeneralQueryPtr general_query = std::make_shared<query::GeneralQuery>();
|
||||
general_query->bin = std::make_shared<query::BinaryQuery>();
|
||||
query::GenBinaryQuery(boolean_query, general_query->bin);
|
||||
|
||||
Status status;
|
||||
|
||||
@ -320,10 +320,9 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
|
||||
// const ::milvus::grpc::CollectionName* request,
|
||||
// ::milvus::grpc::Status* response) override;
|
||||
//
|
||||
// ::grpc::Status
|
||||
// DescribeCollection(::grpc::ServerContext* context,
|
||||
// const ::milvus::grpc::CollectionName* request,
|
||||
// ::milvus::grpc::Mapping* response) override;
|
||||
::grpc::Status
|
||||
DescribeHybridCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
|
||||
::milvus::grpc::Mapping* response) override;
|
||||
//
|
||||
// ::grpc::Status
|
||||
// CountCollection(::grpc::ServerContext* context,
|
||||
|
||||
@ -11,13 +11,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include <oatpp/web/server/api/ApiController.hpp>
|
||||
#include <oatpp/parser/json/mapping/ObjectMapper.hpp>
|
||||
#include <oatpp/core/macro/codegen.hpp>
|
||||
#include <oatpp/core/macro/component.hpp>
|
||||
#include <oatpp/parser/json/mapping/ObjectMapper.hpp>
|
||||
#include <oatpp/web/server/api/ApiController.hpp>
|
||||
|
||||
#include "utils/Log.h"
|
||||
#include "utils/TimeRecorder.h"
|
||||
@ -37,11 +37,12 @@ namespace web {
|
||||
class WebController : public oatpp::web::server::api::ApiController {
|
||||
public:
|
||||
WebController(const std::shared_ptr<ObjectMapper>& objectMapper)
|
||||
: oatpp::web::server::api::ApiController(objectMapper) {}
|
||||
: oatpp::web::server::api::ApiController(objectMapper) {
|
||||
}
|
||||
|
||||
public:
|
||||
static std::shared_ptr<WebController> createShared(
|
||||
OATPP_COMPONENT(std::shared_ptr<ObjectMapper>, objectMapper)) {
|
||||
static std::shared_ptr<WebController>
|
||||
createShared(OATPP_COMPONENT(std::shared_ptr<ObjectMapper>, objectMapper)) {
|
||||
return std::make_shared<WebController>(objectMapper);
|
||||
}
|
||||
|
||||
@ -84,8 +85,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -115,8 +116,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -139,8 +140,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -172,8 +173,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
return response;
|
||||
@ -197,8 +198,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
return response;
|
||||
}
|
||||
@ -229,8 +230,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
return response;
|
||||
}
|
||||
@ -243,7 +244,6 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
|
||||
String result;
|
||||
auto status_dto = handler.ShowTables(query_params, result);
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
@ -255,8 +255,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
return response;
|
||||
@ -270,8 +270,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
|
||||
ADD_CORS(GetTable)
|
||||
|
||||
ENDPOINT("GET", "/collections/{collection_name}", GetTable,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
ENDPOINT("GET", "/collections/{collection_name}", GetTable, PATH(String, collection_name),
|
||||
QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
@ -292,8 +292,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
return response;
|
||||
@ -320,8 +320,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
return response;
|
||||
@ -335,8 +335,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
|
||||
ADD_CORS(CreateIndex)
|
||||
|
||||
ENDPOINT("POST", "/collections/{collection_name}/indexes", CreateIndex,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
ENDPOINT("POST", "/collections/{collection_name}/indexes", CreateIndex, PATH(String, collection_name),
|
||||
BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/tables/" + collection_name->std_str() + "/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
@ -355,8 +355,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
return response;
|
||||
@ -365,7 +365,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
ADD_CORS(GetIndex)
|
||||
|
||||
ENDPOINT("GET", "/collections/{collection_name}/indexes", GetIndex, PATH(String, collection_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/indexes\'");
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
|
||||
"/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
@ -385,8 +386,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
return response;
|
||||
@ -395,7 +396,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
ADD_CORS(DropIndex)
|
||||
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}/indexes", DropIndex, PATH(String, collection_name)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() + "/indexes\'");
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
|
||||
"/indexes\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
@ -413,8 +415,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
|
||||
return response;
|
||||
@ -428,9 +430,10 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
|
||||
ADD_CORS(CreatePartition)
|
||||
|
||||
ENDPOINT("POST", "/collections/{collection_name}/partitions",
|
||||
CreatePartition, PATH(String, collection_name), BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/partitions\'");
|
||||
ENDPOINT("POST", "/collections/{collection_name}/partitions", CreatePartition, PATH(String, collection_name),
|
||||
BODY_DTO(PartitionRequestDto::ObjectWrapper, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
|
||||
"/partitions\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
@ -448,17 +451,18 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(ShowPartitions)
|
||||
|
||||
ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() + "/partitions\'");
|
||||
ENDPOINT("GET", "/collections/{collection_name}/partitions", ShowPartitions, PATH(String, collection_name),
|
||||
QUERIES(const QueryParams&, query_params)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/collections/" + collection_name->std_str() +
|
||||
"/partitions\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto offset = query_params.get("offset");
|
||||
@ -476,21 +480,22 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
break;
|
||||
default:response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(DropPartition)
|
||||
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) +
|
||||
"DELETE \'/collections/" + collection_name->std_str() + "/partitions\'");
|
||||
ENDPOINT("DELETE", "/collections/{collection_name}/partitions", DropPartition, PATH(String, collection_name),
|
||||
BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "DELETE \'/collections/" + collection_name->std_str() +
|
||||
"/partitions\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto handler = WebRequestHandler();
|
||||
@ -508,16 +513,16 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(ShowSegments)
|
||||
|
||||
ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
ENDPOINT("GET", "/collections/{collection_name}/segments", ShowSegments, PATH(String, collection_name),
|
||||
QUERIES(const QueryParams&, query_params)) {
|
||||
auto offset = query_params.get("offset");
|
||||
auto page_size = query_params.get("page_size");
|
||||
|
||||
@ -541,7 +546,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
* GetSegmentVector
|
||||
*/
|
||||
ENDPOINT("GET", "/collections/{collection_name}/segments/{segment_name}/{info}", GetSegmentInfo,
|
||||
PATH(String, collection_name), PATH(String, segment_name), PATH(String, info), QUERIES(const QueryParams&, query_params)) {
|
||||
PATH(String, collection_name), PATH(String, segment_name), PATH(String, info),
|
||||
QUERIES(const QueryParams&, query_params)) {
|
||||
auto offset = query_params.get("offset");
|
||||
auto page_size = query_params.get("page_size");
|
||||
|
||||
@ -570,8 +576,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
*
|
||||
* GetVectorByID ?id=
|
||||
*/
|
||||
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors,
|
||||
PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) {
|
||||
ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, PATH(String, collection_name),
|
||||
QUERIES(const QueryParams&, query_params)) {
|
||||
auto handler = WebRequestHandler();
|
||||
String response;
|
||||
auto status_dto = handler.GetVector(collection_name, query_params, response);
|
||||
@ -588,9 +594,10 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
|
||||
ADD_CORS(Insert)
|
||||
|
||||
ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() + "/vectors\'");
|
||||
ENDPOINT("POST", "/collections/{collection_name}/vectors", Insert, PATH(String, collection_name),
|
||||
BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/collections/" + collection_name->std_str() +
|
||||
"/vectors\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto ids_dto = VectorIdsDto::createShared();
|
||||
@ -609,17 +616,48 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(InsertEntity)
|
||||
|
||||
ENDPOINT("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name),
|
||||
BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections/" + collection_name->std_str() +
|
||||
"/entities\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
auto ids_dto = VectorIdsDto::createShared();
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.InsertEntity(collection_name, body, ids_dto);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, ids_dto);
|
||||
break;
|
||||
case StatusCode::COLLECTION_NOT_EXISTS:
|
||||
response = createDtoResponse(Status::CODE_404, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(VectorsOp)
|
||||
|
||||
ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp,
|
||||
PATH(String, collection_name), BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() + "/vectors\'");
|
||||
ENDPOINT("PUT", "/collections/{collection_name}/vectors", VectorsOp, PATH(String, collection_name),
|
||||
BODY_STRING(String, body)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "PUT \'/collections/" + collection_name->std_str() +
|
||||
"/vectors\'");
|
||||
tr.RecordSection("Received request.");
|
||||
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
@ -638,8 +676,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -668,8 +706,8 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -694,19 +732,41 @@ class WebController : public oatpp::web::server::api::ApiController {
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue())
|
||||
+ ", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
tr.ElapseFromBegin("Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost");
|
||||
|
||||
return response;
|
||||
}
|
||||
|
||||
ADD_CORS(CreateHybridCollection)
|
||||
|
||||
ENDPOINT("POST", "/hybrid_collections", CreateHybridCollection, BODY_STRING(String, body_str)) {
|
||||
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "POST \'/hybrid_collections\'");
|
||||
tr.RecordSection("Received request.");
|
||||
WebRequestHandler handler = WebRequestHandler();
|
||||
|
||||
std::shared_ptr<OutgoingResponse> response;
|
||||
auto status_dto = handler.CreateHybridCollection(body_str);
|
||||
switch (status_dto->code->getValue()) {
|
||||
case StatusCode::SUCCESS:
|
||||
response = createDtoResponse(Status::CODE_201, status_dto);
|
||||
break;
|
||||
default:
|
||||
response = createDtoResponse(Status::CODE_400, status_dto);
|
||||
}
|
||||
|
||||
std::string ttr = "Done. Status: code = " + std::to_string(status_dto->code->getValue()) +
|
||||
", reason = " + status_dto->message->std_str() + ". Total cost";
|
||||
tr.ElapseFromBegin(ttr);
|
||||
return response;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finish ENDPOINTs generation ('ApiController' codegen)
|
||||
*/
|
||||
#include OATPP_CODEGEN_END(ApiController)
|
||||
|
||||
};
|
||||
|
||||
} // namespace web
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
} // namespace web
|
||||
} // namespace server
|
||||
} // namespace milvus
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include <cmath>
|
||||
#include <ctime>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "config/Config.h"
|
||||
@ -567,6 +568,251 @@ WebRequestHandler::Search(const std::string& collection_name, const nlohmann::js
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::ProcessLeafQueryJson(const nlohmann::json& json, milvus::query::BooleanQueryPtr& query) {
|
||||
if (json.contains("term")) {
|
||||
auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
auto term_json = json["term"];
|
||||
std::string field_name = term_json["field_name"];
|
||||
auto term_value_json = term_json["values"];
|
||||
if (!term_value_json.is_array()) {
|
||||
std::string msg = "Term json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
|
||||
auto term_size = term_value_json.size();
|
||||
auto term_query = std::make_shared<query::TermQuery>();
|
||||
term_query->field_name = field_name;
|
||||
term_query->field_value.resize(term_size * sizeof(int64_t));
|
||||
|
||||
switch (field_type_.at(field_name)) {
|
||||
case engine::meta::hybrid::DataType::INT8:
|
||||
case engine::meta::hybrid::DataType::INT16:
|
||||
case engine::meta::hybrid::DataType::INT32:
|
||||
case engine::meta::hybrid::DataType::INT64: {
|
||||
std::vector<int64_t> term_value(term_size, 0);
|
||||
for (uint64_t i = 0; i < term_size; ++i) {
|
||||
term_value[i] = term_value_json[i].get<int64_t>();
|
||||
}
|
||||
memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(int64_t));
|
||||
break;
|
||||
}
|
||||
case engine::meta::hybrid::DataType::FLOAT:
|
||||
case engine::meta::hybrid::DataType::DOUBLE: {
|
||||
std::vector<double> term_value(term_size, 0);
|
||||
for (uint64_t i = 0; i < term_size; ++i) {
|
||||
term_value[i] = term_value_json[i].get<double>();
|
||||
}
|
||||
memcpy(term_query->field_value.data(), term_value.data(), term_size * sizeof(double));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
leaf_query->term_query = term_query;
|
||||
query->AddLeafQuery(leaf_query);
|
||||
} else if (json.contains("range")) {
|
||||
auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
auto range_query = std::make_shared<query::RangeQuery>();
|
||||
|
||||
auto range_json = json["range"];
|
||||
std::string field_name = range_json["field_name"];
|
||||
range_query->field_name = field_name;
|
||||
|
||||
auto range_value_json = range_json["values"];
|
||||
if (range_value_json.contains("lt")) {
|
||||
query::CompareExpr compare_expr;
|
||||
compare_expr.compare_operator = query::CompareOperator::LT;
|
||||
compare_expr.operand = range_value_json["lt"].get<std::string>();
|
||||
range_query->compare_expr.emplace_back(compare_expr);
|
||||
}
|
||||
if (range_value_json.contains("lte")) {
|
||||
query::CompareExpr compare_expr;
|
||||
compare_expr.compare_operator = query::CompareOperator::LTE;
|
||||
compare_expr.operand = range_value_json["lte"].get<std::string>();
|
||||
range_query->compare_expr.emplace_back(compare_expr);
|
||||
}
|
||||
if (range_value_json.contains("eq")) {
|
||||
query::CompareExpr compare_expr;
|
||||
compare_expr.compare_operator = query::CompareOperator::EQ;
|
||||
compare_expr.operand = range_value_json["eq"].get<std::string>();
|
||||
range_query->compare_expr.emplace_back(compare_expr);
|
||||
}
|
||||
if (range_value_json.contains("ne")) {
|
||||
query::CompareExpr compare_expr;
|
||||
compare_expr.compare_operator = query::CompareOperator::NE;
|
||||
compare_expr.operand = range_value_json["ne"].get<std::string>();
|
||||
range_query->compare_expr.emplace_back(compare_expr);
|
||||
}
|
||||
if (range_value_json.contains("gt")) {
|
||||
query::CompareExpr compare_expr;
|
||||
compare_expr.compare_operator = query::CompareOperator::GT;
|
||||
compare_expr.operand = range_value_json["gt"].get<std::string>();
|
||||
range_query->compare_expr.emplace_back(compare_expr);
|
||||
}
|
||||
if (range_value_json.contains("gte")) {
|
||||
query::CompareExpr compare_expr;
|
||||
compare_expr.compare_operator = query::CompareOperator::GTE;
|
||||
compare_expr.operand = range_value_json["gte"].get<std::string>();
|
||||
range_query->compare_expr.emplace_back(compare_expr);
|
||||
}
|
||||
|
||||
leaf_query->range_query = range_query;
|
||||
query->AddLeafQuery(leaf_query);
|
||||
} else if (json.contains("vector")) {
|
||||
auto leaf_query = std::make_shared<query::LeafQuery>();
|
||||
auto vector_query = std::make_shared<query::VectorQuery>();
|
||||
|
||||
auto vector_json = json["vector"];
|
||||
std::string field_name = vector_json["field_name"];
|
||||
vector_query->field_name = field_name;
|
||||
|
||||
engine::VectorsData vectors;
|
||||
// TODO(yukun): process binary vector
|
||||
CopyRecordsFromJson(vector_json["values"], vectors, false);
|
||||
|
||||
vector_query->query_vector.float_data = vectors.float_data_;
|
||||
vector_query->query_vector.binary_data = vectors.binary_data_;
|
||||
|
||||
vector_query->topk = vector_json["topk"].get<int64_t>();
|
||||
vector_query->extra_params = vector_json["extra_params"];
|
||||
|
||||
leaf_query->vector_query = vector_query;
|
||||
query->AddLeafQuery(leaf_query);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query) {
|
||||
if (query_json.contains("must")) {
|
||||
boolean_query->SetOccur(query::Occur::MUST);
|
||||
auto must_json = query_json["must"];
|
||||
if (!must_json.is_array()) {
|
||||
std::string msg = "Must json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
|
||||
for (auto& json : must_json) {
|
||||
auto must_query = std::make_shared<query::BooleanQuery>();
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
ProcessBoolQueryJson(json, must_query);
|
||||
boolean_query->AddBooleanQuery(must_query);
|
||||
} else {
|
||||
ProcessLeafQueryJson(json, boolean_query);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else if (query_json.contains("should")) {
|
||||
boolean_query->SetOccur(query::Occur::SHOULD);
|
||||
auto should_json = query_json["should"];
|
||||
if (!should_json.is_array()) {
|
||||
std::string msg = "Should json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
|
||||
for (auto& json : should_json) {
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
auto should_query = std::make_shared<query::BooleanQuery>();
|
||||
ProcessBoolQueryJson(json, should_query);
|
||||
boolean_query->AddBooleanQuery(should_query);
|
||||
} else {
|
||||
ProcessLeafQueryJson(json, boolean_query);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else if (query_json.contains("must_not")) {
|
||||
boolean_query->SetOccur(query::Occur::MUST_NOT);
|
||||
auto should_json = query_json["must_not"];
|
||||
if (!should_json.is_array()) {
|
||||
std::string msg = "Must_not json string is not an array";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
|
||||
for (auto& json : should_json) {
|
||||
if (json.contains("must") || json.contains("should") || json.contains("must_not")) {
|
||||
auto must_not_query = std::make_shared<query::BooleanQuery>();
|
||||
ProcessBoolQueryJson(json, must_not_query);
|
||||
boolean_query->AddBooleanQuery(must_not_query);
|
||||
} else {
|
||||
ProcessLeafQueryJson(json, boolean_query);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
} else {
|
||||
std::string msg = "Must json string doesnot include right query";
|
||||
return Status{BODY_PARSE_FAIL, msg};
|
||||
}
|
||||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::HybridSearch(const std::string& collection_name, const nlohmann::json& json,
|
||||
std::string& result_str) {
|
||||
Status status;
|
||||
|
||||
status = request_handler_.DescribeHybridCollection(context_ptr_, collection_name, field_type_);
|
||||
if (!status.ok()) {
|
||||
return Status{UNEXPECTED_ERROR, "DescribeHybridCollection failed"};
|
||||
}
|
||||
|
||||
std::vector<std::string> partition_tags;
|
||||
if (json.contains("partition_tags")) {
|
||||
auto tags = json["partition_tags"];
|
||||
if (!tags.is_null() && !tags.is_array()) {
|
||||
return Status(BODY_PARSE_FAIL, "Field \"partition_tags\" must be a array");
|
||||
}
|
||||
|
||||
for (auto& tag : tags) {
|
||||
partition_tags.emplace_back(tag.get<std::string>());
|
||||
}
|
||||
}
|
||||
|
||||
if (json.contains("bool")) {
|
||||
auto boolean_query_json = json["bool"];
|
||||
query::BooleanQueryPtr boolean_query = std::make_shared<query::BooleanQuery>();
|
||||
|
||||
status = ProcessBoolQueryJson(boolean_query_json, boolean_query);
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
query::GeneralQueryPtr general_query = std::make_shared<query::GeneralQuery>();
|
||||
query::GenBinaryQuery(boolean_query, general_query->bin);
|
||||
|
||||
context::HybridSearchContextPtr hybrid_search_context = std::make_shared<context::HybridSearchContext>();
|
||||
TopKQueryResult result;
|
||||
status = request_handler_.HybridSearch(context_ptr_, hybrid_search_context, collection_name, partition_tags,
|
||||
general_query, result);
|
||||
|
||||
if (!status.ok()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
nlohmann::json result_json;
|
||||
result_json["num"] = result.row_num_;
|
||||
if (result.row_num_ == 0) {
|
||||
result_json["result"] = std::vector<int64_t>();
|
||||
result_str = result_json.dump();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
auto step = result.id_list_.size() / result.row_num_;
|
||||
nlohmann::json search_result_json;
|
||||
for (size_t i = 0; i < result.row_num_; i++) {
|
||||
nlohmann::json raw_result_json;
|
||||
for (size_t j = 0; j < step; j++) {
|
||||
nlohmann::json one_result_json;
|
||||
one_result_json["id"] = std::to_string(result.id_list_.at(i * step + j));
|
||||
one_result_json["distance"] = std::to_string(result.distance_list_.at(i * step + j));
|
||||
raw_result_json.emplace_back(one_result_json);
|
||||
}
|
||||
search_result_json.emplace_back(raw_result_json);
|
||||
}
|
||||
result_json["result"] = search_result_json;
|
||||
result_str = result_json.dump();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
WebRequestHandler::DeleteByIDs(const std::string& collection_name, const nlohmann::json& json,
|
||||
std::string& result_str) {
|
||||
@ -930,6 +1176,50 @@ WebRequestHandler::CreateTable(const TableRequestDto::ObjectWrapper& collection_
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::CreateHybridCollection(const milvus::server::web::OString& body) {
|
||||
auto json_str = nlohmann::json::parse(body->c_str());
|
||||
std::string collection_name = json_str["collection_name"];
|
||||
|
||||
// TODO(yukun): do checking
|
||||
std::vector<std::pair<std::string, engine::meta::hybrid::DataType>> field_types;
|
||||
std::vector<std::pair<std::string, std::string>> field_extra_params;
|
||||
std::vector<std::pair<std::string, uint64_t>> vector_dimensions;
|
||||
for (auto& field : json_str["fields"]) {
|
||||
std::string field_name = field["field_name"];
|
||||
std::string field_type = field["field_type"];
|
||||
auto extra_params = field["extra_params"];
|
||||
if (field_type == "int8") {
|
||||
field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT8));
|
||||
} else if (field_type == "int16") {
|
||||
field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT16));
|
||||
} else if (field_type == "int32") {
|
||||
field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT32));
|
||||
} else if (field_type == "int64") {
|
||||
field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::INT64));
|
||||
} else if (field_type == "float") {
|
||||
field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::FLOAT));
|
||||
} else if (field_type == "double") {
|
||||
field_types.emplace_back(std::make_pair(field_name, engine::meta::hybrid::DataType::DOUBLE));
|
||||
} else if (field_type == "vector") {
|
||||
} else {
|
||||
std::string msg = field_name + " has wrong field_type";
|
||||
RETURN_STATUS_DTO(BODY_PARSE_FAIL, msg.c_str());
|
||||
}
|
||||
|
||||
field_extra_params.emplace_back(std::make_pair(field_name, extra_params.dump()));
|
||||
|
||||
if (extra_params.contains("dimension")) {
|
||||
vector_dimensions.emplace_back(std::make_pair(field_name, extra_params["dimension"].get<uint64_t>()));
|
||||
}
|
||||
}
|
||||
|
||||
auto status = request_handler_.CreateHybridCollection(context_ptr_, collection_name, field_types, vector_dimensions,
|
||||
field_extra_params);
|
||||
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::ShowTables(const OQueryParams& query_params, OString& result) {
|
||||
int64_t offset = 0;
|
||||
@ -1347,6 +1637,106 @@ WebRequestHandler::Insert(const OString& collection_name, const OString& body, V
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::server::web::OString& body,
|
||||
VectorIdsDto::ObjectWrapper& ids_dto) {
|
||||
if (nullptr == body.get() || body->getSize() == 0) {
|
||||
RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Request payload is required.")
|
||||
}
|
||||
|
||||
auto body_json = nlohmann::json::parse(body->c_str());
|
||||
std::string partition_tag = body_json["partition_tag"];
|
||||
|
||||
uint64_t row_num = body_json["row_num"];
|
||||
|
||||
std::unordered_map<std::string, engine::meta::hybrid::DataType> field_types;
|
||||
auto status = request_handler_.DescribeHybridCollection(context_ptr_, collection_name->c_str(), field_types);
|
||||
|
||||
auto entities = body_json["entity"];
|
||||
if (!entities.is_array()) {
|
||||
RETURN_STATUS_DTO(ILLEGAL_BODY, "An entity must be an array");
|
||||
}
|
||||
|
||||
std::vector<std::string> field_names;
|
||||
std::vector<std::vector<uint8_t>> attr_values;
|
||||
size_t attr_size = 0;
|
||||
std::unordered_map<std::string, engine::VectorsData> vector_datas;
|
||||
for (auto& entity : entities) {
|
||||
std::string field_name = entity["field_name"];
|
||||
field_names.emplace_back(field_name);
|
||||
auto field_value = entity["field_value"];
|
||||
std::vector<uint8_t> attr_value;
|
||||
switch (field_types.at(field_name)) {
|
||||
case engine::meta::hybrid::DataType::INT8:
|
||||
case engine::meta::hybrid::DataType::INT16:
|
||||
case engine::meta::hybrid::DataType::INT32:
|
||||
case engine::meta::hybrid::DataType::INT64: {
|
||||
std::vector<int64_t> value;
|
||||
auto size = field_value.size();
|
||||
value.resize(size);
|
||||
attr_value.resize(size * sizeof(int64_t));
|
||||
size_t offset = 0;
|
||||
for (auto data : field_value) {
|
||||
value[offset] = data.get<int64_t>();
|
||||
++offset;
|
||||
}
|
||||
memcpy(attr_value.data(), value.data(), size * sizeof(int64_t));
|
||||
attr_size += size * sizeof(int64_t);
|
||||
attr_values.emplace_back(attr_value);
|
||||
break;
|
||||
}
|
||||
case engine::meta::hybrid::DataType::FLOAT:
|
||||
case engine::meta::hybrid::DataType::DOUBLE: {
|
||||
std::vector<double> value;
|
||||
auto size = field_value.size();
|
||||
value.resize(size);
|
||||
attr_value.resize(size * sizeof(double));
|
||||
size_t offset = 0;
|
||||
for (auto data : field_value) {
|
||||
value[offset] = data.get<double>();
|
||||
++offset;
|
||||
}
|
||||
memcpy(attr_value.data(), value.data(), size * sizeof(double));
|
||||
attr_size += size * sizeof(double);
|
||||
|
||||
attr_values.emplace_back(attr_value);
|
||||
break;
|
||||
}
|
||||
case engine::meta::hybrid::DataType::VECTOR: {
|
||||
bool bin_flag;
|
||||
status = IsBinaryTable(collection_name->c_str(), bin_flag);
|
||||
if (!status.ok()) {
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
engine::VectorsData vectors;
|
||||
CopyRecordsFromJson(field_value, vectors, bin_flag);
|
||||
vector_datas.insert(std::make_pair(field_name, vectors));
|
||||
}
|
||||
default: {}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> attrs(attr_size, 0);
|
||||
size_t attr_offset = 0;
|
||||
for (auto& data : attr_values) {
|
||||
memcpy(attrs.data() + attr_offset, data.data(), data.size());
|
||||
attr_offset += data.size();
|
||||
}
|
||||
|
||||
status = request_handler_.InsertEntity(context_ptr_, collection_name->c_str(), partition_tag, row_num, field_names,
|
||||
attrs, vector_datas);
|
||||
|
||||
if (status.ok()) {
|
||||
ids_dto->ids = ids_dto->ids->createShared();
|
||||
for (auto& id : vector_datas.begin()->second.id_array_) {
|
||||
ids_dto->ids->pushBack(std::to_string(id).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
ASSIGN_RETURN_STATUS_DTO(status)
|
||||
}
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) {
|
||||
int64_t id = 0;
|
||||
@ -1389,6 +1779,8 @@ WebRequestHandler::VectorsOp(const OString& collection_name, const OString& payl
|
||||
status = DeleteByIDs(collection_name->std_str(), payload_json["delete"], result_str);
|
||||
} else if (payload_json.contains("search")) {
|
||||
status = Search(collection_name->std_str(), payload_json["search"], result_str);
|
||||
} else if (payload_json.contains("query")) {
|
||||
status = HybridSearch(collection_name->c_str(), payload_json["query"], result_str);
|
||||
} else {
|
||||
status = Status(ILLEGAL_BODY, "Unknown body");
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -141,6 +142,15 @@ class WebRequestHandler {
|
||||
Status
|
||||
Search(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
|
||||
|
||||
Status
|
||||
ProcessLeafQueryJson(const nlohmann::json& json, query::BooleanQueryPtr& boolean_query);
|
||||
|
||||
Status
|
||||
ProcessBoolQueryJson(const nlohmann::json& query_json, query::BooleanQueryPtr& boolean_query);
|
||||
|
||||
Status
|
||||
HybridSearch(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
|
||||
|
||||
Status
|
||||
DeleteByIDs(const std::string& collection_name, const nlohmann::json& json, std::string& result_str);
|
||||
|
||||
@ -176,6 +186,9 @@ class WebRequestHandler {
|
||||
StatusDto::ObjectWrapper
|
||||
ShowTables(const OQueryParams& query_params, OString& result);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
CreateHybridCollection(const OString& body);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
GetTable(const OString& collection_name, const OQueryParams& query_params, OString& result);
|
||||
|
||||
@ -219,6 +232,9 @@ class WebRequestHandler {
|
||||
StatusDto::ObjectWrapper
|
||||
Insert(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
InsertEntity(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto);
|
||||
|
||||
StatusDto::ObjectWrapper
|
||||
GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response);
|
||||
|
||||
@ -244,6 +260,7 @@ class WebRequestHandler {
|
||||
private:
|
||||
std::shared_ptr<Context> context_ptr_;
|
||||
RequestHandler request_handler_;
|
||||
std::unordered_map<std::string, engine::meta::hybrid::DataType> field_type_;
|
||||
};
|
||||
|
||||
} // namespace web
|
||||
|
||||
@ -156,6 +156,17 @@ RandomBinRecordsJson(int64_t dim, int64_t num) {
|
||||
return json;
|
||||
}
|
||||
|
||||
nlohmann::json
|
||||
RandomAttrRecordsJson(int64_t row_num) {
|
||||
nlohmann::json json;
|
||||
std::default_random_engine e;
|
||||
std::uniform_int_distribution<unsigned> u(0, 1000);
|
||||
for (size_t i = 0; i < row_num; i++) {
|
||||
json.push_back(u(e));
|
||||
}
|
||||
return json;
|
||||
}
|
||||
|
||||
std::string
|
||||
RandomName() {
|
||||
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
@ -697,6 +708,12 @@ class TestClient : public oatpp::web::client::ApiClient {
|
||||
|
||||
API_CALL("PUT", "/system/{op}", op, PATH(String, cmd_str, "op"), BODY_STRING(String, body))
|
||||
|
||||
API_CALL("POST", "/hybrid_collections", createHybridCollection, BODY_STRING(String, body_str))
|
||||
|
||||
API_CALL("POST", "/hybrid_collections/{collection_name}/entities", InsertEntity, PATH(String, collection_name), BODY_STRING(String, body))
|
||||
|
||||
// API_CALL("POST", "/hybrid_collections/{collection_name}/vectors", HybridSearch, PATH(String, collection_name), BODY_STRING(String, body))
|
||||
|
||||
#include OATPP_CODEGEN_END(ApiClient)
|
||||
};
|
||||
|
||||
@ -967,6 +984,92 @@ TEST_F(WebControllerTest, CREATE_COLLECTION) {
|
||||
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
|
||||
}
|
||||
|
||||
TEST_F(WebControllerTest, HYBRID_TEST) {
|
||||
nlohmann::json create_json;
|
||||
create_json["collection_name"] = "test_hybrid";
|
||||
nlohmann::json field_json_0, field_json_1;
|
||||
field_json_0["field_name"] = "field_0";
|
||||
field_json_0["field_type"] = "int64";
|
||||
field_json_0["extra_params"] = "";
|
||||
|
||||
field_json_1["field_name"] = "field_1";
|
||||
field_json_1["field_type"] = "vector";
|
||||
nlohmann::json extra_params;
|
||||
extra_params["dimension"] = 128;
|
||||
field_json_1["extra_params"] = extra_params;
|
||||
|
||||
create_json["fields"].push_back(field_json_0);
|
||||
create_json["fields"].push_back(field_json_1);
|
||||
|
||||
auto response = client_ptr->createHybridCollection(create_json.dump().c_str());
|
||||
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
|
||||
auto result_dto = response->readBodyToDto<milvus::server::web::StatusDto>(object_mapper.get());
|
||||
ASSERT_EQ(milvus::server::web::StatusCode::SUCCESS, result_dto->code->getValue()) << result_dto->message->std_str();
|
||||
|
||||
int64_t dimension = 128;
|
||||
int64_t row_num = 1000;
|
||||
nlohmann::json insert_json;
|
||||
insert_json["partition_tag"] = "";
|
||||
nlohmann::json entity_0, entity_1;
|
||||
entity_0["field_name"] = "field_0";
|
||||
entity_0["field_value"] = RandomAttrRecordsJson(row_num);
|
||||
entity_1["field_name"] = "field_1";
|
||||
entity_1["field_value"] = RandomRecordsJson(dimension, row_num);
|
||||
|
||||
insert_json["entity"].push_back(entity_0);
|
||||
insert_json["entity"].push_back(entity_1);
|
||||
insert_json["row_num"] = row_num;
|
||||
|
||||
OString collection_name = "test_hybrid";
|
||||
response = client_ptr->InsertEntity(collection_name, insert_json.dump().c_str(), conncetion_ptr);
|
||||
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
|
||||
auto vector_dto = response->readBodyToDto<milvus::server::web::VectorIdsDto>(object_mapper.get());
|
||||
ASSERT_EQ(row_num, vector_dto->ids->count());
|
||||
|
||||
auto status = FlushTable(client_ptr, conncetion_ptr, collection_name);
|
||||
ASSERT_TRUE(status.ok()) << status.message();
|
||||
|
||||
// TODO(yukun): when hybrid operation is added to wal, the sleep() can be deleted
|
||||
sleep(2);
|
||||
|
||||
int64_t nq = 10;
|
||||
int64_t topk = 100;
|
||||
nlohmann::json query_json, bool_json, term_json, range_json, vector_json;
|
||||
term_json["term"]["field_name"] = "field_0";
|
||||
term_json["term"]["values"] = RandomAttrRecordsJson(nq);
|
||||
bool_json["must"].push_back(term_json);
|
||||
|
||||
range_json["range"]["field_name"] = "field_0";
|
||||
nlohmann::json comp_json;
|
||||
comp_json["gte"] = "0";
|
||||
comp_json["lte"] = "100000";
|
||||
range_json["range"]["values"] = comp_json;
|
||||
bool_json["must"].push_back(range_json);
|
||||
|
||||
vector_json["vector"]["field_name"] = "field_1";
|
||||
vector_json["vector"]["topk"] = topk;
|
||||
vector_json["vector"]["nq"] = nq;
|
||||
vector_json["vector"]["values"] = RandomRecordsJson(128, nq);
|
||||
bool_json["must"].push_back(vector_json);
|
||||
|
||||
query_json["query"]["bool"] = bool_json;
|
||||
|
||||
response = client_ptr->vectorsOp(collection_name, query_json.dump().c_str(), conncetion_ptr);
|
||||
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
|
||||
|
||||
auto result_json = nlohmann::json::parse(response->readBodyToString()->std_str());
|
||||
ASSERT_TRUE(result_json.contains("num"));
|
||||
ASSERT_TRUE(result_json["num"].is_number());
|
||||
ASSERT_EQ(nq, result_json["num"].get<int64_t>());
|
||||
|
||||
ASSERT_TRUE(result_json.contains("result"));
|
||||
ASSERT_TRUE(result_json["result"].is_array());
|
||||
|
||||
auto result0_json = result_json["result"][0];
|
||||
ASSERT_TRUE(result0_json.is_array());
|
||||
ASSERT_EQ(topk, result0_json.size());
|
||||
}
|
||||
|
||||
TEST_F(WebControllerTest, GET_COLLECTION_META) {
|
||||
OString collection_name = "web_test_create_collection" + OString(RandomName().c_str());
|
||||
GenTable(client_ptr, conncetion_ptr, collection_name, 10, 10, "L2");
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user