diff --git a/CHANGELOG.md b/CHANGELOG.md index d9acff7d93..24ea875270 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,6 +65,7 @@ Please mark all change in change log and use the issue from GitHub - \#2256 k-means clustering algorithm use only Euclidean distance metric - \#2300 Upgrade mishrads configuration to version 0.4 - \#2311 Update mishards methods +- \#2330 Change url for behavior 'get_entities_by_id' ## Task diff --git a/core/src/server/web_impl/controller/WebController.hpp b/core/src/server/web_impl/controller/WebController.hpp index 7c582dfe70..e53d9add39 100644 --- a/core/src/server/web_impl/controller/WebController.hpp +++ b/core/src/server/web_impl/controller/WebController.hpp @@ -577,11 +577,11 @@ class WebController : public oatpp::web::server::api::ApiController { * * GetVectorByID ?id= */ - ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, PATH(String, collection_name), - BODY_STRING(String, body), QUERIES(const QueryParams&, query_params)) { + ENDPOINT("GET", "/collections/{collection_name}/vectors", GetVectors, + PATH(String, collection_name), QUERIES(const QueryParams&, query_params)) { auto handler = WebRequestHandler(); String response; - auto status_dto = handler.GetVector(collection_name, body, query_params, response); + auto status_dto = handler.GetVector(collection_name, query_params, response); switch (status_dto->code->getValue()) { case StatusCode::SUCCESS: diff --git a/core/src/server/web_impl/handler/WebRequestHandler.cpp b/core/src/server/web_impl/handler/WebRequestHandler.cpp index 0b5f536bd6..06a8dde88c 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.cpp +++ b/core/src/server/web_impl/handler/WebRequestHandler.cpp @@ -1693,22 +1693,20 @@ WebRequestHandler::InsertEntity(const OString& collection_name, const milvus::se } StatusDto::ObjectWrapper -WebRequestHandler::GetVector(const OString& collection_name, const OString& body, const OQueryParams& query_params, - OString& response) { +WebRequestHandler::GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response) { auto status = Status::OK(); try { - auto body_json = nlohmann::json::parse(body->c_str()); - if (!body_json.contains("ids")) { - RETURN_STATUS_DTO(BODY_FIELD_LOSS, "Field \'ids\' is required.") - } - auto ids = body_json["ids"]; - if (!ids.is_array()) { - RETURN_STATUS_DTO(BODY_PARSE_FAIL, "Field \'ids\' must be a array.") + auto query_ids = query_params.get("ids"); + if (query_ids == nullptr || query_ids.get() == nullptr) { + RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param ids is required."); } + std::vector ids; + StringHelpFunctions::SplitStringByDelimeter(query_ids->c_str(), ",", ids); + std::vector vector_ids; for (auto& id : ids) { - vector_ids.push_back(std::stol(id.get())); + vector_ids.push_back(std::stol(id)); } engine::VectorsData vectors; nlohmann::json vectors_json; diff --git a/core/src/server/web_impl/handler/WebRequestHandler.h b/core/src/server/web_impl/handler/WebRequestHandler.h index 216c4825ad..57f37ded6e 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.h +++ b/core/src/server/web_impl/handler/WebRequestHandler.h @@ -220,7 +220,7 @@ class WebRequestHandler { InsertEntity(const OString& collection_name, const OString& body, VectorIdsDto::ObjectWrapper& ids_dto); StatusDto::ObjectWrapper - GetVector(const OString& collection_name, const OString& body, const OQueryParams& query_params, OString& response); + GetVector(const OString& collection_name, const OQueryParams& query_params, OString& response); StatusDto::ObjectWrapper VectorsOp(const OString& collection_name, const OString& payload, OString& response); diff --git a/core/unittest/server/test_web.cpp b/core/unittest/server/test_web.cpp index 0bbc2f46e7..4b327b4bc8 100644 --- a/core/unittest/server/test_web.cpp +++ b/core/unittest/server/test_web.cpp @@ -33,6 +33,7 @@ #include "server/web_impl/handler/WebRequestHandler.h" #include "src/version.h" #include "utils/CommonUtil.h" +#include "utils/StringHelpFunctions.h" static const char* COLLECTION_NAME = "test_web"; @@ -325,7 +326,7 @@ class TestClient : public oatpp::web::client::ApiClient { PATH(String, collection_name, "collection_name")) API_CALL("GET", "/collections/{collection_name}/vectors", getVectors, - PATH(String, collection_name, "collection_name"), BODY_STRING(String, body)) + PATH(String, collection_name, "collection_name"), QUERY(String, ids)) API_CALL("POST", "/collections/{collection_name}/vectors", insert, PATH(String, collection_name, "collection_name"), BODY_STRING(String, body)) @@ -1302,9 +1303,10 @@ TEST_F(WebControllerTest, GET_VECTORS_BY_IDS) { for (size_t i = 0; i < 10; i++) { vector_ids.emplace_back(ids.at(i)); } - auto body = nlohmann::json(); - body["ids"] = vector_ids; - auto response = client_ptr->getVectors(collection_name, body.dump().c_str(), conncetion_ptr); + + std::string query_ids; + milvus::server::StringHelpFunctions::MergeStringWithDelimeter(vector_ids, ",", query_ids); + auto response = client_ptr->getVectors(collection_name, query_ids.c_str(), conncetion_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()) << response->readBodyToString()->c_str(); // validate result @@ -1329,7 +1331,7 @@ TEST_F(WebControllerTest, GET_VECTORS_BY_IDS) { ASSERT_EQ(64, vec.size()); // non-existent collection - response = client_ptr->getVectors(collection_name + "_non_existent", body.dump().c_str(), conncetion_ptr); + response = client_ptr->getVectors(collection_name + "_non_existent", query_ids.c_str(), conncetion_ptr); ASSERT_EQ(OStatus::CODE_404.code, response->getStatusCode()) << response->readBodyToString()->c_str(); }