diff --git a/CHANGELOG.md b/CHANGELOG.md index 168a7fe87c..f76bd380ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ Please mark all changes in change log and use the issue from GitHub - \#3793 Frequent timeout when running test cases in parallel - \#4030 The configurations name is not effective after update config file and restart docker - \#4059 milvus raise exception when passing a complex dsl during searching +- \#4099 When execute a multiple vectors request, if the result of somerow is not enough, the query data field will be disorder. - \#4113 After milvus 0.11.0 version vector data is misaligned ## Feature diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index 83f6fa4539..e48b5f9903 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -1752,10 +1752,7 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery( if (dsl_json.contains("bool")) { auto boolean_query_json = dsl_json["bool"]; JSON_NULL_CHECK(boolean_query_json); - status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr); - if (!status.ok()) { - return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool"); - } + STATUS_CHECK(ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr)); } else { return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool query"); } diff --git a/sdk/CMakeLists.txt b/sdk/CMakeLists.txt index c42b80ae35..2053ffd54c 100644 --- a/sdk/CMakeLists.txt +++ b/sdk/CMakeLists.txt @@ -43,7 +43,7 @@ message(STATUS "Build type = ${BUILD_TYPE}") unset(CMAKE_EXPORT_COMPILE_COMMANDS CACHE) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED on) if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)") diff --git a/sdk/examples/simple/src/ClientTest.cpp b/sdk/examples/simple/src/ClientTest.cpp index 8cd2a45244..2d5554bd78 100644 --- a/sdk/examples/simple/src/ClientTest.cpp +++ b/sdk/examples/simple/src/ClientTest.cpp @@ -23,17 +23,17 @@ namespace { const char* COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName().c_str(); +const char* PARTITION_TAG = "American"; -constexpr int64_t COLLECTION_DIMENSION = 512; +constexpr int64_t COLLECTION_DIMENSION = 8; constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2; -constexpr int64_t BATCH_ENTITY_COUNT = 10000; -constexpr int64_t NQ = 5; -constexpr int64_t TOP_K = 10; +constexpr int64_t BATCH_ENTITY_COUNT = 3; +constexpr int64_t NQ = 1; +constexpr int64_t TOP_K = 3; constexpr int64_t NPROBE = 16; constexpr int64_t SEARCH_TARGET = BATCH_ENTITY_COUNT / 2; // change this value, result is different -constexpr int64_t ADD_ENTITY_LOOP = 10; +constexpr int64_t ADD_ENTITY_LOOP = 1; constexpr int32_t NLIST = 1024; -const char* PARTITION_TAG = "part"; const char* DIMENSION = "dim"; const char* METRICTYPE = "metric_type"; const char* INDEXTYPE = "index_type"; @@ -77,87 +77,85 @@ ClientTest::ListCollections(std::vector& collection_array) { } void -ClientTest::CreateCollection(const std::string& collection_name) { - milvus::FieldPtr field_ptr1 = std::make_shared(); - milvus::FieldPtr field_ptr2 = std::make_shared(); - milvus::FieldPtr field_ptr4 = std::make_shared(); +ClientTest::CreateCollection() { + milvus::FieldPtr field1 = std::make_shared("release_year", milvus::DataType::INT32, ""); + milvus::FieldPtr field2 = std::make_shared("duration", milvus::DataType::INT32, ""); + nlohmann::json vector_param = {{"dim", COLLECTION_DIMENSION}}; + milvus::FieldPtr field3 = + std::make_shared("embedding", milvus::DataType::VECTOR_FLOAT, vector_param.dump()); - field_ptr1->field_name = "field_1"; - field_ptr1->field_type = milvus::DataType::INT64; - JSON index_param_1; - index_param_1["name"] = "index_1"; - field_ptr1->index_params = index_param_1.dump(); + nlohmann::json json_param; + json_param = {{"auto_id", false}, {"segment_row_limit", 4096}}; + milvus::Mapping mapping = {COLLECTION_NAME, {field1, field2, field3}, json_param.dump()}; - field_ptr2->field_name = "field_2"; - field_ptr2->field_type = milvus::DataType::FLOAT; - JSON index_param_2; - index_param_2["name"] = "index_2"; - field_ptr2->index_params = index_param_2.dump(); - - field_ptr4->field_name = "field_vec"; - field_ptr4->field_type = milvus::DataType::VECTOR_FLOAT; - JSON index_param_4; - index_param_4["name"] = "index_vec"; - field_ptr4->index_params = index_param_4.dump(); - JSON extra_params_4; - extra_params_4[DIMENSION] = COLLECTION_DIMENSION; - field_ptr4->extra_params = extra_params_4.dump(); - - JSON extra_params; - extra_params["segment_row_limit"] = 10000; - extra_params["auto_id"] = false; - milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr4}}; - - milvus::Status stat = conn_->CreateCollection(mapping, extra_params.dump()); - std::cout << "CreateCollection function call status: " << stat.message() << std::endl; + milvus::Status status = conn_->CreateCollection(mapping); + std::cout << "CreateCollection function call status: " << status.message() << std::endl; } void -ClientTest::GetCollectionInfo(const std::string& collection_name) { +ClientTest::CreatePartition() { + milvus::PartitionParam param = milvus::PartitionParam{COLLECTION_NAME, PARTITION_TAG}; + auto status = conn_->CreatePartition(param); + std::cout << "CreatePartition function call status: " << status.message() << std::endl; +} + +void +ClientTest::GetCollectionInfo() { milvus::Mapping mapping; - milvus::Status stat = conn_->GetCollectionInfo(collection_name, mapping); + milvus::Status status = conn_->GetCollectionInfo(COLLECTION_NAME, mapping); milvus_sdk::Utils::PrintMapping(mapping); - std::cout << "GetCollectionInfo function call status: " << stat.message() << std::endl; + std::cout << "GetCollectionInfo function call status: " << status.message() << std::endl; } void -ClientTest::InsertEntities(const std::string& collection_name) { - for (int64_t i = 0; i < ADD_ENTITY_LOOP; i++) { - milvus::FieldValue field_value; - std::vector entity_ids; - - int64_t begin_index = i * BATCH_ENTITY_COUNT; - { - milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i)); - milvus_sdk::Utils::BuildEntities(begin_index, begin_index + BATCH_ENTITY_COUNT, field_value, entity_ids, - COLLECTION_DIMENSION); - } - milvus::Status status = conn_->Insert(collection_name, "", field_value, entity_ids); - search_id_array_.emplace_back(entity_ids[10]); - std::cout << "InsertEntities function call status: " << status.message() << std::endl; - std::cout << "Returned id array count: " << entity_ids.size() << std::endl; +ClientTest::ListPartitions() { + milvus::PartitionTagList partition_list; + auto status = conn_->ListPartitions(COLLECTION_NAME, partition_list); + std::cout << "Partitions: "; + for (const auto& part : partition_list) { + std::cout << part << std::endl; } + std::cout << "ListPartitions function call status: " << status.message() << std::endl; } void -ClientTest::CountEntities(const std::string& collection_name) { - int64_t entity_count = 0; - auto status = conn_->CountEntities(collection_name, entity_count); - std::cout << "Collection " << collection_name << " entity count: " << entity_count << std::endl; +ClientTest::InsertEntities() { + std::vector duration{208, 226, 252}; + std::vector release_year{2001, 2002, 2003}; + std::vector embedding; + milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION, 3, embedding); + + milvus::FieldValue field_value; + std::unordered_map> int32_value = {{"duration", duration}, + {"release_year", release_year}}; + + std::unordered_map> vector_value = {{"embedding", embedding}}; + field_value.int32_value = int32_value; + field_value.vector_value = vector_value; + + std::vector id_array = {1, 2, 3}; + auto status = conn_->Insert(COLLECTION_NAME, PARTITION_TAG, field_value, id_array); + std::cout << "InsertEntities function call status: " << status.message() << std::endl; } void -ClientTest::Flush(const std::string& collection_name) { +ClientTest::CountEntities(int64_t& entity_count) { + auto status = conn_->CountEntities(COLLECTION_NAME, entity_count); + std::cout << "Collection " << COLLECTION_NAME << " entity count: " << entity_count << std::endl; +} + +void +ClientTest::Flush() { milvus_sdk::TimeRecorder rc("Flush"); - std::vector collections = {collection_name}; + std::vector collections = {COLLECTION_NAME}; milvus::Status stat = conn_->Flush(collections); std::cout << "Flush function call status: " << stat.message() << std::endl; } void -ClientTest::GetCollectionStats(const std::string& collection_name) { +ClientTest::GetCollectionStats() { std::string collection_stats; - milvus::Status stat = conn_->GetCollectionStats(collection_name, collection_stats); + milvus::Status stat = conn_->GetCollectionStats(COLLECTION_NAME, collection_stats); std::cout << "Collection stats: " << collection_stats << std::endl; std::cout << "GetCollectionStats function call status: " << stat.message() << std::endl; } @@ -178,76 +176,53 @@ ClientTest::BuildVectors(int64_t nq, int64_t dimension) { } void -ClientTest::GetEntityByID(const std::string& collection_name, const std::vector& id_array) { - std::string result; +ClientTest::GetEntityByID(const std::vector& id_array) { + milvus::Entities entities; { milvus_sdk::TimeRecorder rc("GetEntityByID"); - milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, result); + milvus::Status stat = conn_->GetEntityByID(COLLECTION_NAME, id_array, entities); std::cout << "GetEntityByID function call status: " << stat.message() << std::endl; - } + }; + std::cout << "GetEntityByID function result: " << std::endl; - JSON result_json = JSON::parse(result); - for (const auto& one_result : result_json) { - std::cout << one_result << std::endl; + for (const auto& entity : entities) { + std::cout << "Entity id: " << entity.entity_id << std::endl; + for (const auto& data : entity.scalar_data) { + if (data.first == "duration" || data.first == "release_year") { + std::cout << data.first << ": " << std::any_cast(data.second) << std::endl; + } + } + for (const auto& data : entity.vector_data) { + auto embedding = data.second.float_data; + std::cout << data.first << ":"; + for (const auto& v : embedding) { + std::cout << v << " "; + } + std::cout << std::endl; + } } } void -ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe, - const std::string metric_type) { +ClientTest::SearchEntities() { nlohmann::json dsl_json, vector_param_json; - milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, metric_type); + milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, TOP_K, "L2"); - std::vector record_ids; - std::vector temp_entity_array; - for (auto& pair : search_entity_array_) { - temp_entity_array.push_back(pair.second); - } + std::vector query_embedding; + milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION, 1, query_embedding); - milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array}; + milvus::VectorParam vector_param = {vector_param_json.dump(), query_embedding}; + + std::vector get_fields{"duration", "release_year", "embedding"}; + nlohmann::json json_params = {{"fields", get_fields}}; std::vector partition_tags; milvus::TopKQueryResult topk_query_result; - auto status = conn_->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, "", topk_query_result); + auto status = conn_->Search(COLLECTION_NAME, partition_tags, dsl_json.dump(), vector_param, json_params.dump(), + topk_query_result); - std::cout << metric_type << " Search function call result: " << std::endl; + std::cout << " Search function call result: " << std::endl; milvus_sdk::Utils::PrintTopKQueryResult(topk_query_result); - std::cout << metric_type << " Search function call status: " << status.message() << std::endl; -} - -void -ClientTest::SearchEntitiesByID(const std::string& collection_name, int64_t topk, int64_t nprobe) { - // std::vector partition_tags; - // milvus::TopKQueryResult topk_query_result; - // - // topk_query_result.clear(); - // - // std::vector id_array; - // for (auto& pair : search_entity_array_) { - // id_array.push_back(pair.first); - // } - // - // std::vector entities; - // milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, entities); - // std::cout << "GetEntityByID function call status: " << stat.message() << std::endl; - // - // JSON json_params = {{"nprobe", nprobe}}; - // milvus_sdk::TimeRecorder rc("Search"); - // stat = conn_->Search(collection_name, partition_tags, entities, topk, json_params.dump(), topk_query_result); - // std::cout << "Search function call status: " << stat.message() << std::endl; - // - // if (topk_query_result.size() != id_array.size()) { - // std::cout << "ERROR! wrong result for query by id" << std::endl; - // return; - // } - // - // for (size_t i = 0; i < id_array.size(); i++) { - // std::cout << "Entity " << id_array[i] << " top " << topk << " search result:" << std::endl; - // const milvus::QueryResult& one_result = topk_query_result[i]; - // for (size_t j = 0; j < one_result.ids.size(); j++) { - // std::cout << "\t" << one_result.ids[j] << "\t" << one_result.distances[j] << std::endl; - // } - // } } void @@ -278,17 +253,16 @@ ClientTest::CompactCollection(const std::string& collection_name) { } void -ClientTest::DeleteByIds(const std::string& collection_name, const std::vector& id_array) { - std::cout << "Delete entity: "; - for (auto id : id_array) { - std::cout << "\t" << id; - } - std::cout << std::endl; +ClientTest::DeleteByIds(const std::vector& id_array) { + auto status = conn_->DeleteEntityByID(COLLECTION_NAME, id_array); + std::cout << "DeleteByID function call status: " << status.message() << std::endl; +} - milvus::Status stat = conn_->DeleteEntityByID(collection_name, id_array); - std::cout << "DeleteByID function call status: " << stat.message() << std::endl; - - Flush(collection_name); +void +ClientTest::DropPartition() { + milvus::PartitionParam param = {COLLECTION_NAME, PARTITION_TAG}; + auto status = conn_->DropPartition(param); + std::cout << "DropPartition function call status: " << status.message() << std::endl; } void @@ -310,43 +284,53 @@ ClientTest::DropCollection(const std::string& collection_name) { void ClientTest::Test() { - std::string collection_name = COLLECTION_NAME; - int64_t dim = COLLECTION_DIMENSION; - milvus::MetricType metric_type = COLLECTION_METRIC_TYPE; + std::vector collection_array; + ListCollections(collection_array); + for (const auto& collection : collection_array) { + DropCollection(collection); + } - std::vector table_array; - ListCollections(table_array); + CreateCollection(); + CreatePartition(); - CreateCollection(collection_name); -// GetCollectionInfo(collection_name); - GetCollectionStats(collection_name); + std::cout << "--------get collection info--------" << std::endl; + GetCollectionInfo(); + std::cout << "\n----------list partitions----------" << std::endl; + ListPartitions(); - ListCollections(table_array); - CountEntities(collection_name); + std::cout << "\n----------insert----------" << std::endl; + InsertEntities(); - InsertEntities(collection_name); - Flush(collection_name); - CountEntities(collection_name); - CreateIndex(collection_name, NLIST); - GetCollectionInfo(collection_name); - // GetCollectionStats(collection_name); - // - LoadCollection(COLLECTION_NAME); - BuildVectors(NQ, COLLECTION_DIMENSION); - // GetEntityByID(collection_name, search_id_array_); - SearchEntities(collection_name, TOP_K, NPROBE, "L2"); - SearchEntities(collection_name, TOP_K, NPROBE, "IP"); - // GetCollectionStats(collection_name); - // - // std::vector delete_ids = {search_id_array_[0], search_id_array_[1]}; - // DeleteByIds(collection_name, delete_ids); - // GetEntityByID(collection_name, search_id_array_); - // CompactCollection(collection_name); - // - // LoadCollection(collection_name); - // SearchEntities(collection_name, TOP_K, NPROBE); // this line get two search error since we delete two - // entities - // - // DropIndex(collection_name, "field_vec", "index_3"); - DropCollection(collection_name); + int64_t before_flush_counts = 0; + int64_t after_flush_counts = 0; + CountEntities(before_flush_counts); + Flush(); + CountEntities(after_flush_counts); + std::cout << "\n----------flush----------" << std::endl; + std::cout << "There are " << before_flush_counts << " films in collection " << COLLECTION_NAME << " before flush" + << std::endl; + std::cout << "There are " << after_flush_counts << " films in collection " << COLLECTION_NAME << " after flush" + << std::endl; + + std::cout << "\n----------get collection stats----------\n"; + GetCollectionStats(); + + std::cout << "\n----------get entity by id = 1, id = 200----------\n"; + std::vector id_array = {1, 200}; + GetEntityByID(id_array); + + std::cout << "\n----------search----------\n"; + SearchEntities(); + + std::vector delete_id_array = {1, 2}; + std::cout << "\n----------delete id = 1, id = 2----------\n"; + DeleteByIds(delete_id_array); + Flush(); + GetEntityByID(delete_id_array); + + int64_t counts_in_collection; + CountEntities(counts_in_collection); + std::cout << "There are " << counts_in_collection << " entities after delete films with 1, 2\n"; + + DropCollection(COLLECTION_NAME); } diff --git a/sdk/examples/simple/src/ClientTest.h b/sdk/examples/simple/src/ClientTest.h index 6e20c057d9..b4e8400014 100644 --- a/sdk/examples/simple/src/ClientTest.h +++ b/sdk/examples/simple/src/ClientTest.h @@ -31,34 +31,37 @@ class ClientTest { ListCollections(std::vector&); void - CreateCollection(const std::string&); + CreateCollection(); void - GetCollectionInfo(const std::string&); + CreatePartition(); void - InsertEntities(const std::string&); + GetCollectionInfo(); void - CountEntities(const std::string&); + ListPartitions(); void - Flush(const std::string&); + InsertEntities(); void - GetCollectionStats(const std::string&); + CountEntities(int64_t& count); + + void + Flush(); + + void + GetCollectionStats(); void BuildVectors(int64_t nq, int64_t dimension); void - GetEntityByID(const std::string&, const std::vector&); + GetEntityByID(const std::vector& id_array); void - SearchEntities(const std::string&, int64_t, int64_t, const std::string metric_type); - - void - SearchEntitiesByID(const std::string&, int64_t, int64_t); + SearchEntities(); void CreateIndex(const std::string&, int64_t); @@ -70,7 +73,10 @@ class ClientTest { CompactCollection(const std::string&); void - DeleteByIds(const std::string&, const std::vector& id_array); + DeleteByIds(const std::vector& id_array); + + void + DropPartition(); void DropIndex(const std::string& collection_name, const std::string& field_name, const std::string& index_name); diff --git a/sdk/examples/utils/Utils.cpp b/sdk/examples/utils/Utils.cpp index cbb19404c7..a132f3f3e4 100644 --- a/sdk/examples/utils/Utils.cpp +++ b/sdk/examples/utils/Utils.cpp @@ -137,10 +137,10 @@ Utils::PrintCollectionParam(const milvus::Mapping& mapping) { BLOCK_SPLITER std::cout << "Collection name: " << mapping.collection_name << std::endl; for (const auto& field : mapping.fields) { - std::cout << "field_name: " << field->field_name; - std::cout << "\tfield_type: " << std::to_string((int)field->field_type); + std::cout << "field_name: " << field->name; + std::cout << "\tfield_type: " << std::to_string((int)field->type); std::cout << "\tindex_param: " << field->index_params; - std::cout << "\textra_param:" << field->extra_params << std::endl; + std::cout << "\textra_param:" << field->params << std::endl; } BLOCK_SPLITER } @@ -167,14 +167,27 @@ Utils::PrintMapping(const milvus::Mapping& mapping) { BLOCK_SPLITER std::cout << "Collection name: " << mapping.collection_name << std::endl; for (const auto& field : mapping.fields) { - std::cout << "field name: " << field->field_name << "\t field type: " << (int32_t)field->field_type - << "\t field index params:" << field->index_params << "\t field extra params: " << field->extra_params + std::cout << "field name: " << field->name << "\t field type: " << (int32_t)field->type + << "\t field index params:" << field->index_params << "\t field extra params: " << field->params << std::endl; } std::cout << "Collection extra params: " << mapping.extra_params << std::endl; BLOCK_SPLITER } +void +Utils::BuildVectors(int64_t dim, int64_t nb, std::vector& vectors) { + std::default_random_engine e; + std::uniform_real_distribution u(0, 1); + vectors.resize(nb); + for (int64_t i = 0; i < nb; i++) { + vectors[i].float_data.resize(dim); + for (int64_t j = 0; j < dim; j++) { + vectors[i].float_data[j] = u(e); + } + } +} + void Utils::BuildEntities(int64_t from, int64_t to, milvus::FieldValue& field_value, std::vector& entity_ids, int64_t dimension) { @@ -336,39 +349,23 @@ Utils::GenLeafQuery() { } void -Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type) { - uint64_t row_num = 10000; - std::vector term_value; - term_value.resize(row_num); - for (uint64_t i = 0; i < row_num; ++i) { - term_value[i] = i; - } +Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, int64_t topk, + const std::string& metric_type) { + std::vector term_vale = {2002, 2003}; + std::vector> embedding; + // dsl_json = {{"bool", + // {"must", + // {{"term", {"release_year", {2002, 2003}}}, + // {{"range", {{"duration", {"GT", 250}}}}}, + // {{"vector", "placeholder_1"}}}}}}; + JSON must_query, term_query, range_query, vector_query; + term_query["term"]["release_year"] = {2002, 2003}; + range_query["range"]["duration"] = {{"GT", 250}}; + vector_query["vector"] = "placeholder_1"; + must_query["must"] = {term_query, range_query, vector_query}; + dsl_json["bool"] = must_query; - nlohmann::json bool_json, term_json, range_json, vector_json; - nlohmann::json term_value_json; - term_value_json["values"] = term_value; - term_json["term"]["field_1"] = term_value_json; - bool_json["must"].push_back(term_json); - - nlohmann::json comp_json; - comp_json["GT"] = 0; - comp_json["LT"] = 100000; - range_json["range"]["field_1"] = comp_json; - bool_json["must"].push_back(range_json); - - std::string placeholder = "placeholder_1"; - vector_json["vector"] = placeholder; - bool_json["must"].push_back(vector_json); - - dsl_json["bool"] = bool_json; - - nlohmann::json query_vector_json, vector_extra_params; - int64_t topk = 10; - query_vector_json["topk"] = topk; - query_vector_json["metric_type"] = metric_type; - vector_extra_params["nprobe"] = 64; - query_vector_json["params"] = vector_extra_params; - vector_param_json[placeholder]["field_vec"] = query_vector_json; + vector_param_json = {{"placeholder_1", {{"embedding", {{"topk", topk}, {"metric_type", metric_type}}}}}}; } void @@ -392,39 +389,27 @@ Utils::GenPureVecDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_ void Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) { for (size_t i = 0; i < topk_query_result.size(); i++) { - auto field_value = topk_query_result[i].field_value; - for (auto& int32_it : field_value.int32_value) { - std::cout << int32_it.first << ":"; - for (auto& data : int32_it.second) { - std::cout << " " << data; - } - std::cout << std::endl; - } - for (auto& int64_it : field_value.int64_value) { - std::cout << int64_it.first << ":"; - for (auto& data : int64_it.second) { - std::cout << " " << data; - } - std::cout << std::endl; - } - for (auto& float_it : field_value.float_value) { - std::cout << float_it.first << ":"; - for (auto& data : float_it.second) { - std::cout << " " << data; - } - std::cout << std::endl; - } - for (auto& double_it : field_value.double_value) { - std::cout << double_it.first << ":"; - for (auto& data : double_it.second) { - std::cout << " " << data; - } - std::cout << std::endl; - } + auto entities = topk_query_result[i].entities; for (size_t j = 0; j < topk_query_result[i].ids.size(); j++) { - std::cout << topk_query_result[i].ids[j] << " --------- " << topk_query_result[i].distances[j] - << std::endl; + std::cout << "- id: " << topk_query_result[i].ids[j] << std::endl; + std::cout << "- distance: " << topk_query_result[i].distances[j] << std::endl; + + for (const auto& data : entities[j].scalar_data) { + if (data.first == "duration" || data.first == "release_year") { + std::cout << "- " << data.first << ": " << std::any_cast(data.second) << std::endl; + } + } + for (const auto& data : entities[j].vector_data) { + if (data.first == "embedding") { + std::cout << "- " << data.first << ": "; + for (const auto& v : data.second.float_data) { + std::cout << v << " "; + } + std::cout << std::endl; + } + } } + std::cout << std::endl; } } diff --git a/sdk/examples/utils/Utils.h b/sdk/examples/utils/Utils.h index 9e51256094..76815f318b 100644 --- a/sdk/examples/utils/Utils.h +++ b/sdk/examples/utils/Utils.h @@ -60,6 +60,9 @@ class Utils { BuildEntities(int64_t from, int64_t to, milvus::FieldValue& field_value, std::vector& entity_ids, int64_t dimension); + static void + BuildVectors(int64_t dim, int64_t nb, std::vector& vectors); + static void PrintSearchResult(const std::vector>& entity_array, const milvus::TopKQueryResult& topk_query_result); @@ -82,7 +85,8 @@ class Utils { GenLeafQuery(); static void - GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type); + GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, int64_t topk, + const std::string& metric_type); static void GenPureVecDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type); diff --git a/sdk/grpc/ClientProxy.cpp b/sdk/grpc/ClientProxy.cpp index 0c573bff98..4ecccc9d77 100644 --- a/sdk/grpc/ClientProxy.cpp +++ b/sdk/grpc/ClientProxy.cpp @@ -66,15 +66,21 @@ CopyRowRecord(::milvus::grpc::VectorRowRecord* target, const VectorData& src) { } void -ConstructTopkResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryResult& topk_query_result) { - topk_query_result.reserve(grpc_result.row_num()); +ConstructTopkQueryResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryResult& topk_query_result) { int64_t nq = grpc_result.row_num(); - int64_t topk = grpc_result.entities().ids_size() / nq; - for (int64_t i = 0; i < nq; i++) { - milvus::QueryResult one_result; + if (nq == 0) { + return; + } + topk_query_result.reserve(nq); + + const auto& grpc_entity = grpc_result.entities(); + int64_t topk = grpc_entity.ids_size() / nq; + int64_t offset = 0; + for (int64_t i = 0; i < grpc_result.row_num(); i++) { + milvus::QueryResult one_result = milvus::QueryResult(); one_result.ids.resize(topk); one_result.distances.resize(topk); - memcpy(one_result.ids.data(), grpc_result.entities().ids().data() + topk * i, topk * sizeof(int64_t)); + memcpy(one_result.ids.data(), grpc_entity.ids().data() + topk * i, topk * sizeof(int64_t)); memcpy(one_result.distances.data(), grpc_result.distances().data() + topk * i, topk * sizeof(float)); int valid_size = one_result.ids.size(); @@ -86,74 +92,52 @@ ConstructTopkResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryRes one_result.distances.resize(valid_size); } - topk_query_result.emplace_back(one_result); - } -} - -void -ConstructTopkQueryResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryResult& topk_query_result) { - int64_t nq = grpc_result.row_num(); - if (nq == 0) { - return; - } - topk_query_result.reserve(nq); - - const auto& grpc_entity = grpc_result.entities(); - int64_t topk = grpc_entity.ids_size() / nq; - // TODO(yukun): filter -1 results - for (int64_t i = 0; i < grpc_result.row_num(); i++) { - milvus::QueryResult one_result; - one_result.ids.resize(topk); - one_result.distances.resize(topk); - memcpy(one_result.ids.data(), grpc_entity.ids().data() + topk * i, topk * sizeof(int64_t)); - memcpy(one_result.distances.data(), grpc_result.distances().data() + topk * i, topk * sizeof(float)); - int64_t j; - for (j = 0; j < grpc_entity.fields_size(); j++) { - auto grpc_field = grpc_entity.fields(j); - if (grpc_field.has_attr_record()) { - if (grpc_field.attr_record().int32_value_size() > 0) { - std::vector int32_data(topk); - memcpy(int32_data.data(), grpc_field.attr_record().int32_value().data() + topk * i, - topk * sizeof(int32_t)); - - one_result.field_value.int32_value.insert(std::make_pair(grpc_field.field_name(), int32_data)); - } else if (grpc_field.attr_record().int64_value_size() > 0) { - std::vector int64_data(topk); - memcpy(int64_data.data(), grpc_field.attr_record().int64_value().data() + topk * i, - topk * sizeof(int64_t)); - one_result.field_value.int64_value.insert(std::make_pair(grpc_field.field_name(), int64_data)); - } else if (grpc_field.attr_record().float_value_size() > 0) { - std::vector float_data(topk); - memcpy(float_data.data(), grpc_field.attr_record().float_value().data() + topk * i, - topk * sizeof(float)); - one_result.field_value.float_value.insert(std::make_pair(grpc_field.field_name(), float_data)); - } else if (grpc_field.attr_record().double_value_size() > 0) { - std::vector double_data(topk); - memcpy(double_data.data(), grpc_field.attr_record().double_value().data() + topk * i, - topk * sizeof(double)); - one_result.field_value.double_value.insert(std::make_pair(grpc_field.field_name(), double_data)); - } - } - if (grpc_field.has_vector_record()) { - int64_t vector_row_count = grpc_field.vector_record().records_size(); - if (vector_row_count > 0) { - std::vector vector_data(topk); - for (int64_t k = topk * i; k < topk * (i + 1); k++) { - auto grpc_vector_data = grpc_field.vector_record().records(k); - if (grpc_vector_data.float_data_size() > 0) { - vector_data[k].float_data.resize(grpc_vector_data.float_data_size()); - memcpy(vector_data[k].float_data.data(), grpc_vector_data.float_data().data(), - grpc_vector_data.float_data_size() * sizeof(float)); - } else if (!grpc_vector_data.binary_data().empty()) { - vector_data[k].binary_data.resize(grpc_vector_data.binary_data().size() / 8); - memcpy(vector_data[k].binary_data.data(), grpc_vector_data.binary_data().data(), - grpc_vector_data.binary_data().size()); + for (int64_t k = 0; k < topk; k++) { + std::unordered_map scalar_data; + std::unordered_map vector_data; + if (grpc_entity.valid_row(i * topk + k)) { + for (int64_t j = 0; j < grpc_entity.fields_size(); j++) { + const auto& grpc_field = grpc_entity.fields(j); + if (grpc_field.has_attr_record()) { + if (grpc_field.attr_record().int32_value_size() > 0) { + scalar_data.insert({grpc_field.field_name(), grpc_field.attr_record().int32_value(offset)}); + } else if (grpc_field.attr_record().int64_value_size() > 0) { + scalar_data.insert({grpc_field.field_name(), grpc_field.attr_record().int64_value(offset)}); + } else if (grpc_field.attr_record().float_value_size() > 0) { + scalar_data.insert({grpc_field.field_name(), grpc_field.attr_record().float_value(offset)}); + } else { + scalar_data.insert( + {grpc_field.field_name(), grpc_field.attr_record().double_value(offset)}); } + } else { + auto float_size = grpc_field.vector_record().records(offset).float_data_size(); + auto bin_size = grpc_field.vector_record().records(offset).binary_data().size(); + milvus::VectorData vectors; + if (float_size > 0) { + std::vector float_data(float_size); + memcpy(float_data.data(), grpc_field.vector_record().records(offset).float_data().data(), + sizeof(float) * float_size); + vectors.float_data = float_data; + } else if (bin_size > 0) { + std::vector bin_data(bin_size / 8); + memcpy(bin_data.data(), grpc_field.vector_record().records(offset).binary_data().data(), + bin_size); + vectors.binary_data = bin_data; + } + vector_data.insert({grpc_field.field_name(), vectors}); } - one_result.field_value.vector_value.insert(std::make_pair(grpc_field.field_name(), vector_data)); + } + if (!scalar_data.empty() || !vector_data.empty()) { + Entity entity; + entity.entity_id = grpc_entity.ids(i * topk + k); + entity.scalar_data = scalar_data; + entity.vector_data = vector_data; + one_result.entities.emplace_back(entity); + offset++; } } } + topk_query_result.emplace_back(one_result); } } @@ -225,6 +209,57 @@ CopyFieldValue(const FieldValue& field_value, ::milvus::grpc::InsertParam& inser } } +void +CopyEntities(::milvus::grpc::Entities& grpc_entities, Entities& entities) { + auto grpc_field_size = grpc_entities.fields_size(); + std::vector field_names(grpc_field_size); + for (int64_t i = 0; i < grpc_field_size; i++) { + field_names[i] = grpc_entities.fields(i).field_name(); + } + + int row_num = grpc_entities.ids_size(); + int64_t offset = 0; + for (int64_t i = 0; i < row_num; i++) { + if (!grpc_entities.valid_row(i)) { + continue; + } + milvus::Entity entity = milvus::Entity(); + entity.entity_id = grpc_entities.ids(i); + for (int64_t j = 0; j < grpc_field_size; j++) { + const auto& grpc_field = grpc_entities.fields(j); + auto field_name = grpc_field.field_name(); + if (grpc_field.has_attr_record()) { + const auto& grpc_attr_record = grpc_field.attr_record(); + if (grpc_attr_record.int32_value_size() > 0) { + entity.scalar_data.insert({field_name, grpc_attr_record.int32_value(offset)}); + } else if (grpc_attr_record.int64_value_size() > 0) { + entity.scalar_data.insert({field_name, grpc_attr_record.int64_value(offset)}); + } else if (grpc_attr_record.float_value_size() > 0) { + entity.scalar_data.insert({field_name, grpc_attr_record.float_value(offset)}); + } else if (grpc_attr_record.double_value_size() > 0) { + entity.scalar_data.insert({field_name, grpc_attr_record.double_value(offset)}); + } + } else if (grpc_field.has_vector_record()) { + const auto& grpc_vector_record = grpc_field.vector_record(); + const auto& record = grpc_vector_record.records(offset); + milvus::VectorData vector_data; + if (record.float_data_size() > 0) { + std::vector data(record.float_data_size()); + memcpy(data.data(), record.float_data().data(), record.float_data_size() * sizeof(float)); + vector_data.float_data = data; + } else if (record.binary_data().size() > 0) { + std::vector data(record.binary_data().size()); + memcpy(data.data(), record.binary_data().data(), record.binary_data().size()); + vector_data.binary_data = data; + } + entity.vector_data.insert({field_name, vector_data}); + } + } + entities.emplace_back(entity); + offset++; + } +} + void CopyEntityToJson(::milvus::grpc::Entities& grpc_entities, JSON& json_entity) { int i; @@ -401,29 +436,23 @@ ClientProxy::Disconnect() { } Status -ClientProxy::CreateCollection(const Mapping& mapping, const std::string& extra_params) { +ClientProxy::CreateCollection(const Mapping& mapping) { CLIENT_NULL_CHECK(client_ptr_); try { ::milvus::grpc::Mapping grpc_mapping; grpc_mapping.set_collection_name(mapping.collection_name); for (auto& field : mapping.fields) { auto grpc_field = grpc_mapping.add_fields(); - grpc_field->set_name(field->field_name); - grpc_field->set_type((::milvus::grpc::DataType)field->field_type); - JSON json_index_param = JSON::parse(field->index_params); - for (auto& json_param : json_index_param.items()) { - auto grpc_index_param = grpc_field->add_index_params(); - grpc_index_param->set_key(json_param.key()); - grpc_index_param->set_value(json_param.value()); - } + grpc_field->set_name(field->name); + grpc_field->set_type((::milvus::grpc::DataType)field->type); auto grpc_extra_param = grpc_field->add_extra_params(); grpc_extra_param->set_key(EXTRA_PARAM_KEY); - grpc_extra_param->set_value(field->extra_params); + grpc_extra_param->set_value(field->params); } auto grpc_param = grpc_mapping.add_extra_params(); grpc_param->set_key(EXTRA_PARAM_KEY); - grpc_param->set_value(extra_params); + grpc_param->set_value(mapping.extra_params); return client_ptr_->CreateCollection(grpc_mapping); } catch (std::exception& ex) { @@ -485,7 +514,7 @@ ClientProxy::GetCollectionInfo(const std::string& collection_name, Mapping& mapp for (int64_t i = 0; i < grpc_mapping.fields_size(); i++) { const auto& grpc_field = grpc_mapping.fields(i); FieldPtr field_ptr = std::make_shared(); - field_ptr->field_name = grpc_field.name(); + field_ptr->name = grpc_field.name(); JSON json_index_params; for (int64_t j = 0; j < grpc_field.index_params_size(); j++) { JSON json_param; @@ -499,8 +528,8 @@ ClientProxy::GetCollectionInfo(const std::string& collection_name, Mapping& mapp json_param = JSON::parse(grpc_field.extra_params(j).value()); json_extra_params.emplace_back(json_param); } - field_ptr->extra_params = json_extra_params.dump(); - field_ptr->field_type = (DataType)grpc_field.type(); + field_ptr->params = json_extra_params.dump(); + field_ptr->type = (DataType)grpc_field.type(); mapping.fields.emplace_back(field_ptr); } if (!grpc_mapping.extra_params().empty()) { @@ -676,7 +705,7 @@ ClientProxy::Insert(const std::string& collection_name, const std::string& parti Status ClientProxy::GetEntityByID(const std::string& collection_name, const std::vector& id_array, - std::string& entities) { + Entities& entities) { CLIENT_NULL_CHECK(client_ptr_); try { ::milvus::grpc::EntityIdentity entity_identity; @@ -691,9 +720,7 @@ ClientProxy::GetEntityByID(const std::string& collection_name, const std::vector return status; } - JSON json_entities; - CopyEntityToJson(grpc_entities, json_entities); - entities = json_entities.dump(); + CopyEntities(grpc_entities, entities); return status; } catch (std::exception& ex) { return Status(StatusCode::UnknownError, "Failed to get entity by id: " + std::string(ex.what())); diff --git a/sdk/grpc/ClientProxy.h b/sdk/grpc/ClientProxy.h index 65d1cfa314..bb8257cc24 100644 --- a/sdk/grpc/ClientProxy.h +++ b/sdk/grpc/ClientProxy.h @@ -36,7 +36,7 @@ class ClientProxy : public Connection { Disconnect() override; Status - CreateCollection(const Mapping& mapping, const std::string& extra_params) override; + CreateCollection(const Mapping& mapping) override; Status DropCollection(const std::string& collection_name) override; @@ -81,7 +81,7 @@ class ClientProxy : public Connection { Status GetEntityByID(const std::string& collection_name, const std::vector& id_array, - std::string& entities) override; + Entities& entities) override; Status DeleteEntityByID(const std::string& collection_name, const std::vector& id_array) override; diff --git a/sdk/include/Field.h b/sdk/include/Field.h index 542520aa05..19f4d3adc1 100644 --- a/sdk/include/Field.h +++ b/sdk/include/Field.h @@ -16,7 +16,6 @@ #include #include "Status.h" - namespace milvus { enum class DataType { @@ -40,11 +39,16 @@ enum class DataType { // Base struct of all fields struct Field { - uint64_t field_id; ///< read-only - std::string field_name; - DataType field_type; - std::string index_params; - std::string extra_params; + std::string name; + DataType type; + std::string index_params = "{}"; + std::string params; + + Field(const std::string& _name, const DataType _type, const std::string& _params) + : name(_name), type(_type), params(_params) { + } + + Field() = default; }; using FieldPtr = std::shared_ptr; @@ -63,4 +67,4 @@ struct VectorField : Field { }; using VectorFieldPtr = std::shared_ptr; -} // namespace milvus +} // namespace milvus diff --git a/sdk/include/MilvusApi.h b/sdk/include/MilvusApi.h index a16a814c50..84fc2577a8 100644 --- a/sdk/include/MilvusApi.h +++ b/sdk/include/MilvusApi.h @@ -76,6 +76,13 @@ struct AttrRecord { /** * @brief field value */ +struct Entity { + int64_t entity_id; + std::unordered_map scalar_data; + std::unordered_map vector_data; +}; +using Entities = std::vector; + struct FieldValue { int64_t row_num; std::unordered_map> int32_value; @@ -99,6 +106,7 @@ struct VectorParam { struct QueryResult { std::vector ids; ///< Query entity ids result std::vector distances; ///< Query distances result + Entities entities; FieldValue field_value; }; using TopKQueryResult = std::vector; ///< Topk hybrid query result @@ -229,7 +237,7 @@ class Connection { * @return Indicate if collection is created successfully */ virtual Status - CreateCollection(const Mapping& mapping, const std::string& extra_params) = 0; + CreateCollection(const Mapping& mapping) = 0; /** * @brief Drop collection method @@ -416,7 +424,7 @@ class Connection { * @return Indicate if the operation is succeed. */ virtual Status - GetEntityByID(const std::string& collection_name, const std::vector& id_array, std::string& entities) = 0; + GetEntityByID(const std::string& collection_name, const std::vector& id_array, Entities& entities) = 0; /** * @brief Delete entity by id diff --git a/sdk/interface/ConnectionImpl.cpp b/sdk/interface/ConnectionImpl.cpp index 95b565f85e..be0724d1e9 100644 --- a/sdk/interface/ConnectionImpl.cpp +++ b/sdk/interface/ConnectionImpl.cpp @@ -52,8 +52,8 @@ ConnectionImpl::Disconnect() { } Status -ConnectionImpl::CreateCollection(const Mapping& mapping, const std::string& extra_params) { - return client_proxy_->CreateCollection(mapping, extra_params); +ConnectionImpl::CreateCollection(const Mapping& mapping) { + return client_proxy_->CreateCollection(mapping); } Status @@ -124,7 +124,7 @@ ConnectionImpl::Insert(const std::string& collection_name, const std::string& pa Status ConnectionImpl::GetEntityByID(const std::string& collection_name, const std::vector& id_array, - std::string& entities) { + Entities& entities) { return client_proxy_->GetEntityByID(collection_name, id_array, entities); } diff --git a/sdk/interface/ConnectionImpl.h b/sdk/interface/ConnectionImpl.h index bf7862e476..fad94334b4 100644 --- a/sdk/interface/ConnectionImpl.h +++ b/sdk/interface/ConnectionImpl.h @@ -53,7 +53,7 @@ class ConnectionImpl : public Connection { // SetConfig(const std::string& node_name, const std::string& value) const override; Status - CreateCollection(const Mapping& mapping, const std::string& extra_params) override; + CreateCollection(const Mapping& mapping) override; Status DropCollection(const std::string& collection_name) override; @@ -98,7 +98,7 @@ class ConnectionImpl : public Connection { Status GetEntityByID(const std::string& collection_name, const std::vector& id_array, - std::string& entities) override; + Entities& entities) override; Status DeleteEntityByID(const std::string& collection_name, const std::vector& id_array) override;