When execute a multiple vectors request, if the result of somerow is not enough, the query data field will be disorder. (#4140)

* Change returned field_value to entities in C++ sdk

Signed-off-by: fishpenguin <kun.yu@zilliz.com>

* When execute a multiple vectors request, if the result of somerow is not enough, the query data field will be disorder.

Signed-off-by: fishpenguin <kun.yu@zilliz.com>
This commit is contained in:
yukun 2020-10-30 19:22:00 +08:00 committed by GitHub
parent a701c2f1de
commit 4b3a4eddeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 367 additions and 351 deletions

View File

@ -8,6 +8,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#3793 Frequent timeout when running test cases in parallel
- \#4030 The configurations name is not effective after update config file and restart docker
- \#4059 milvus raise exception when passing a complex dsl during searching
- \#4099 When execute a multiple vectors request, if the result of somerow is not enough, the query data field will be disorder.
- \#4113 After milvus 0.11.0 version vector data is misaligned
## Feature

View File

@ -1752,10 +1752,7 @@ GrpcRequestHandler::DeserializeJsonToBoolQuery(
if (dsl_json.contains("bool")) {
auto boolean_query_json = dsl_json["bool"];
JSON_NULL_CHECK(boolean_query_json);
status = ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr);
if (!status.ok()) {
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool");
}
STATUS_CHECK(ProcessBooleanQueryJson(boolean_query_json, boolean_query, query_ptr));
} else {
return Status(SERVER_INVALID_DSL_PARAMETER, "DSL does not include bool query");
}

View File

@ -43,7 +43,7 @@ message(STATUS "Build type = ${BUILD_TYPE}")
unset(CMAKE_EXPORT_COMPILE_COMMANDS CACHE)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED on)
if (CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")

View File

@ -23,17 +23,17 @@
namespace {
const char* COLLECTION_NAME = milvus_sdk::Utils::GenCollectionName().c_str();
const char* PARTITION_TAG = "American";
constexpr int64_t COLLECTION_DIMENSION = 512;
constexpr int64_t COLLECTION_DIMENSION = 8;
constexpr milvus::MetricType COLLECTION_METRIC_TYPE = milvus::MetricType::L2;
constexpr int64_t BATCH_ENTITY_COUNT = 10000;
constexpr int64_t NQ = 5;
constexpr int64_t TOP_K = 10;
constexpr int64_t BATCH_ENTITY_COUNT = 3;
constexpr int64_t NQ = 1;
constexpr int64_t TOP_K = 3;
constexpr int64_t NPROBE = 16;
constexpr int64_t SEARCH_TARGET = BATCH_ENTITY_COUNT / 2; // change this value, result is different
constexpr int64_t ADD_ENTITY_LOOP = 10;
constexpr int64_t ADD_ENTITY_LOOP = 1;
constexpr int32_t NLIST = 1024;
const char* PARTITION_TAG = "part";
const char* DIMENSION = "dim";
const char* METRICTYPE = "metric_type";
const char* INDEXTYPE = "index_type";
@ -77,87 +77,85 @@ ClientTest::ListCollections(std::vector<std::string>& collection_array) {
}
void
ClientTest::CreateCollection(const std::string& collection_name) {
milvus::FieldPtr field_ptr1 = std::make_shared<milvus::Field>();
milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
milvus::FieldPtr field_ptr4 = std::make_shared<milvus::Field>();
ClientTest::CreateCollection() {
milvus::FieldPtr field1 = std::make_shared<milvus::Field>("release_year", milvus::DataType::INT32, "");
milvus::FieldPtr field2 = std::make_shared<milvus::Field>("duration", milvus::DataType::INT32, "");
nlohmann::json vector_param = {{"dim", COLLECTION_DIMENSION}};
milvus::FieldPtr field3 =
std::make_shared<milvus::Field>("embedding", milvus::DataType::VECTOR_FLOAT, vector_param.dump());
field_ptr1->field_name = "field_1";
field_ptr1->field_type = milvus::DataType::INT64;
JSON index_param_1;
index_param_1["name"] = "index_1";
field_ptr1->index_params = index_param_1.dump();
nlohmann::json json_param;
json_param = {{"auto_id", false}, {"segment_row_limit", 4096}};
milvus::Mapping mapping = {COLLECTION_NAME, {field1, field2, field3}, json_param.dump()};
field_ptr2->field_name = "field_2";
field_ptr2->field_type = milvus::DataType::FLOAT;
JSON index_param_2;
index_param_2["name"] = "index_2";
field_ptr2->index_params = index_param_2.dump();
field_ptr4->field_name = "field_vec";
field_ptr4->field_type = milvus::DataType::VECTOR_FLOAT;
JSON index_param_4;
index_param_4["name"] = "index_vec";
field_ptr4->index_params = index_param_4.dump();
JSON extra_params_4;
extra_params_4[DIMENSION] = COLLECTION_DIMENSION;
field_ptr4->extra_params = extra_params_4.dump();
JSON extra_params;
extra_params["segment_row_limit"] = 10000;
extra_params["auto_id"] = false;
milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr4}};
milvus::Status stat = conn_->CreateCollection(mapping, extra_params.dump());
std::cout << "CreateCollection function call status: " << stat.message() << std::endl;
milvus::Status status = conn_->CreateCollection(mapping);
std::cout << "CreateCollection function call status: " << status.message() << std::endl;
}
void
ClientTest::GetCollectionInfo(const std::string& collection_name) {
ClientTest::CreatePartition() {
milvus::PartitionParam param = milvus::PartitionParam{COLLECTION_NAME, PARTITION_TAG};
auto status = conn_->CreatePartition(param);
std::cout << "CreatePartition function call status: " << status.message() << std::endl;
}
void
ClientTest::GetCollectionInfo() {
milvus::Mapping mapping;
milvus::Status stat = conn_->GetCollectionInfo(collection_name, mapping);
milvus::Status status = conn_->GetCollectionInfo(COLLECTION_NAME, mapping);
milvus_sdk::Utils::PrintMapping(mapping);
std::cout << "GetCollectionInfo function call status: " << stat.message() << std::endl;
std::cout << "GetCollectionInfo function call status: " << status.message() << std::endl;
}
void
ClientTest::InsertEntities(const std::string& collection_name) {
for (int64_t i = 0; i < ADD_ENTITY_LOOP; i++) {
milvus::FieldValue field_value;
std::vector<int64_t> entity_ids;
int64_t begin_index = i * BATCH_ENTITY_COUNT;
{
milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i));
milvus_sdk::Utils::BuildEntities(begin_index, begin_index + BATCH_ENTITY_COUNT, field_value, entity_ids,
COLLECTION_DIMENSION);
}
milvus::Status status = conn_->Insert(collection_name, "", field_value, entity_ids);
search_id_array_.emplace_back(entity_ids[10]);
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
std::cout << "Returned id array count: " << entity_ids.size() << std::endl;
ClientTest::ListPartitions() {
milvus::PartitionTagList partition_list;
auto status = conn_->ListPartitions(COLLECTION_NAME, partition_list);
std::cout << "Partitions: ";
for (const auto& part : partition_list) {
std::cout << part << std::endl;
}
std::cout << "ListPartitions function call status: " << status.message() << std::endl;
}
void
ClientTest::CountEntities(const std::string& collection_name) {
int64_t entity_count = 0;
auto status = conn_->CountEntities(collection_name, entity_count);
std::cout << "Collection " << collection_name << " entity count: " << entity_count << std::endl;
ClientTest::InsertEntities() {
std::vector<int32_t> duration{208, 226, 252};
std::vector<int32_t> release_year{2001, 2002, 2003};
std::vector<milvus::VectorData> embedding;
milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION, 3, embedding);
milvus::FieldValue field_value;
std::unordered_map<std::string, std::vector<int32_t>> int32_value = {{"duration", duration},
{"release_year", release_year}};
std::unordered_map<std::string, std::vector<milvus::VectorData>> vector_value = {{"embedding", embedding}};
field_value.int32_value = int32_value;
field_value.vector_value = vector_value;
std::vector<int64_t> id_array = {1, 2, 3};
auto status = conn_->Insert(COLLECTION_NAME, PARTITION_TAG, field_value, id_array);
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
}
void
ClientTest::Flush(const std::string& collection_name) {
ClientTest::CountEntities(int64_t& entity_count) {
auto status = conn_->CountEntities(COLLECTION_NAME, entity_count);
std::cout << "Collection " << COLLECTION_NAME << " entity count: " << entity_count << std::endl;
}
void
ClientTest::Flush() {
milvus_sdk::TimeRecorder rc("Flush");
std::vector<std::string> collections = {collection_name};
std::vector<std::string> collections = {COLLECTION_NAME};
milvus::Status stat = conn_->Flush(collections);
std::cout << "Flush function call status: " << stat.message() << std::endl;
}
void
ClientTest::GetCollectionStats(const std::string& collection_name) {
ClientTest::GetCollectionStats() {
std::string collection_stats;
milvus::Status stat = conn_->GetCollectionStats(collection_name, collection_stats);
milvus::Status stat = conn_->GetCollectionStats(COLLECTION_NAME, collection_stats);
std::cout << "Collection stats: " << collection_stats << std::endl;
std::cout << "GetCollectionStats function call status: " << stat.message() << std::endl;
}
@ -178,76 +176,53 @@ ClientTest::BuildVectors(int64_t nq, int64_t dimension) {
}
void
ClientTest::GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array) {
std::string result;
ClientTest::GetEntityByID(const std::vector<int64_t>& id_array) {
milvus::Entities entities;
{
milvus_sdk::TimeRecorder rc("GetEntityByID");
milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, result);
milvus::Status stat = conn_->GetEntityByID(COLLECTION_NAME, id_array, entities);
std::cout << "GetEntityByID function call status: " << stat.message() << std::endl;
}
};
std::cout << "GetEntityByID function result: " << std::endl;
JSON result_json = JSON::parse(result);
for (const auto& one_result : result_json) {
std::cout << one_result << std::endl;
for (const auto& entity : entities) {
std::cout << "Entity id: " << entity.entity_id << std::endl;
for (const auto& data : entity.scalar_data) {
if (data.first == "duration" || data.first == "release_year") {
std::cout << data.first << ": " << std::any_cast<int32_t>(data.second) << std::endl;
}
}
for (const auto& data : entity.vector_data) {
auto embedding = data.second.float_data;
std::cout << data.first << ":";
for (const auto& v : embedding) {
std::cout << v << " ";
}
std::cout << std::endl;
}
}
}
void
ClientTest::SearchEntities(const std::string& collection_name, int64_t topk, int64_t nprobe,
const std::string metric_type) {
ClientTest::SearchEntities() {
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, metric_type);
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, TOP_K, "L2");
std::vector<int64_t> record_ids;
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : search_entity_array_) {
temp_entity_array.push_back(pair.second);
}
std::vector<milvus::VectorData> query_embedding;
milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION, 1, query_embedding);
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
milvus::VectorParam vector_param = {vector_param_json.dump(), query_embedding};
std::vector<std::string> get_fields{"duration", "release_year", "embedding"};
nlohmann::json json_params = {{"fields", get_fields}};
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
auto status = conn_->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, "", topk_query_result);
auto status = conn_->Search(COLLECTION_NAME, partition_tags, dsl_json.dump(), vector_param, json_params.dump(),
topk_query_result);
std::cout << metric_type << " Search function call result: " << std::endl;
std::cout << " Search function call result: " << std::endl;
milvus_sdk::Utils::PrintTopKQueryResult(topk_query_result);
std::cout << metric_type << " Search function call status: " << status.message() << std::endl;
}
void
ClientTest::SearchEntitiesByID(const std::string& collection_name, int64_t topk, int64_t nprobe) {
// std::vector<std::string> partition_tags;
// milvus::TopKQueryResult topk_query_result;
//
// topk_query_result.clear();
//
// std::vector<int64_t> id_array;
// for (auto& pair : search_entity_array_) {
// id_array.push_back(pair.first);
// }
//
// std::vector<milvus::Entity> entities;
// milvus::Status stat = conn_->GetEntityByID(collection_name, id_array, entities);
// std::cout << "GetEntityByID function call status: " << stat.message() << std::endl;
//
// JSON json_params = {{"nprobe", nprobe}};
// milvus_sdk::TimeRecorder rc("Search");
// stat = conn_->Search(collection_name, partition_tags, entities, topk, json_params.dump(), topk_query_result);
// std::cout << "Search function call status: " << stat.message() << std::endl;
//
// if (topk_query_result.size() != id_array.size()) {
// std::cout << "ERROR! wrong result for query by id" << std::endl;
// return;
// }
//
// for (size_t i = 0; i < id_array.size(); i++) {
// std::cout << "Entity " << id_array[i] << " top " << topk << " search result:" << std::endl;
// const milvus::QueryResult& one_result = topk_query_result[i];
// for (size_t j = 0; j < one_result.ids.size(); j++) {
// std::cout << "\t" << one_result.ids[j] << "\t" << one_result.distances[j] << std::endl;
// }
// }
}
void
@ -278,17 +253,16 @@ ClientTest::CompactCollection(const std::string& collection_name) {
}
void
ClientTest::DeleteByIds(const std::string& collection_name, const std::vector<int64_t>& id_array) {
std::cout << "Delete entity: ";
for (auto id : id_array) {
std::cout << "\t" << id;
}
std::cout << std::endl;
ClientTest::DeleteByIds(const std::vector<int64_t>& id_array) {
auto status = conn_->DeleteEntityByID(COLLECTION_NAME, id_array);
std::cout << "DeleteByID function call status: " << status.message() << std::endl;
}
milvus::Status stat = conn_->DeleteEntityByID(collection_name, id_array);
std::cout << "DeleteByID function call status: " << stat.message() << std::endl;
Flush(collection_name);
void
ClientTest::DropPartition() {
milvus::PartitionParam param = {COLLECTION_NAME, PARTITION_TAG};
auto status = conn_->DropPartition(param);
std::cout << "DropPartition function call status: " << status.message() << std::endl;
}
void
@ -310,43 +284,53 @@ ClientTest::DropCollection(const std::string& collection_name) {
void
ClientTest::Test() {
std::string collection_name = COLLECTION_NAME;
int64_t dim = COLLECTION_DIMENSION;
milvus::MetricType metric_type = COLLECTION_METRIC_TYPE;
std::vector<std::string> collection_array;
ListCollections(collection_array);
for (const auto& collection : collection_array) {
DropCollection(collection);
}
std::vector<std::string> table_array;
ListCollections(table_array);
CreateCollection();
CreatePartition();
CreateCollection(collection_name);
// GetCollectionInfo(collection_name);
GetCollectionStats(collection_name);
std::cout << "--------get collection info--------" << std::endl;
GetCollectionInfo();
std::cout << "\n----------list partitions----------" << std::endl;
ListPartitions();
ListCollections(table_array);
CountEntities(collection_name);
std::cout << "\n----------insert----------" << std::endl;
InsertEntities();
InsertEntities(collection_name);
Flush(collection_name);
CountEntities(collection_name);
CreateIndex(collection_name, NLIST);
GetCollectionInfo(collection_name);
// GetCollectionStats(collection_name);
//
LoadCollection(COLLECTION_NAME);
BuildVectors(NQ, COLLECTION_DIMENSION);
// GetEntityByID(collection_name, search_id_array_);
SearchEntities(collection_name, TOP_K, NPROBE, "L2");
SearchEntities(collection_name, TOP_K, NPROBE, "IP");
// GetCollectionStats(collection_name);
//
// std::vector<int64_t> delete_ids = {search_id_array_[0], search_id_array_[1]};
// DeleteByIds(collection_name, delete_ids);
// GetEntityByID(collection_name, search_id_array_);
// CompactCollection(collection_name);
//
// LoadCollection(collection_name);
// SearchEntities(collection_name, TOP_K, NPROBE); // this line get two search error since we delete two
// entities
//
// DropIndex(collection_name, "field_vec", "index_3");
DropCollection(collection_name);
int64_t before_flush_counts = 0;
int64_t after_flush_counts = 0;
CountEntities(before_flush_counts);
Flush();
CountEntities(after_flush_counts);
std::cout << "\n----------flush----------" << std::endl;
std::cout << "There are " << before_flush_counts << " films in collection " << COLLECTION_NAME << " before flush"
<< std::endl;
std::cout << "There are " << after_flush_counts << " films in collection " << COLLECTION_NAME << " after flush"
<< std::endl;
std::cout << "\n----------get collection stats----------\n";
GetCollectionStats();
std::cout << "\n----------get entity by id = 1, id = 200----------\n";
std::vector<int64_t> id_array = {1, 200};
GetEntityByID(id_array);
std::cout << "\n----------search----------\n";
SearchEntities();
std::vector<int64_t> delete_id_array = {1, 2};
std::cout << "\n----------delete id = 1, id = 2----------\n";
DeleteByIds(delete_id_array);
Flush();
GetEntityByID(delete_id_array);
int64_t counts_in_collection;
CountEntities(counts_in_collection);
std::cout << "There are " << counts_in_collection << " entities after delete films with 1, 2\n";
DropCollection(COLLECTION_NAME);
}

View File

@ -31,34 +31,37 @@ class ClientTest {
ListCollections(std::vector<std::string>&);
void
CreateCollection(const std::string&);
CreateCollection();
void
GetCollectionInfo(const std::string&);
CreatePartition();
void
InsertEntities(const std::string&);
GetCollectionInfo();
void
CountEntities(const std::string&);
ListPartitions();
void
Flush(const std::string&);
InsertEntities();
void
GetCollectionStats(const std::string&);
CountEntities(int64_t& count);
void
Flush();
void
GetCollectionStats();
void
BuildVectors(int64_t nq, int64_t dimension);
void
GetEntityByID(const std::string&, const std::vector<int64_t>&);
GetEntityByID(const std::vector<int64_t>& id_array);
void
SearchEntities(const std::string&, int64_t, int64_t, const std::string metric_type);
void
SearchEntitiesByID(const std::string&, int64_t, int64_t);
SearchEntities();
void
CreateIndex(const std::string&, int64_t);
@ -70,7 +73,10 @@ class ClientTest {
CompactCollection(const std::string&);
void
DeleteByIds(const std::string&, const std::vector<int64_t>& id_array);
DeleteByIds(const std::vector<int64_t>& id_array);
void
DropPartition();
void
DropIndex(const std::string& collection_name, const std::string& field_name, const std::string& index_name);

View File

@ -137,10 +137,10 @@ Utils::PrintCollectionParam(const milvus::Mapping& mapping) {
BLOCK_SPLITER
std::cout << "Collection name: " << mapping.collection_name << std::endl;
for (const auto& field : mapping.fields) {
std::cout << "field_name: " << field->field_name;
std::cout << "\tfield_type: " << std::to_string((int)field->field_type);
std::cout << "field_name: " << field->name;
std::cout << "\tfield_type: " << std::to_string((int)field->type);
std::cout << "\tindex_param: " << field->index_params;
std::cout << "\textra_param:" << field->extra_params << std::endl;
std::cout << "\textra_param:" << field->params << std::endl;
}
BLOCK_SPLITER
}
@ -167,14 +167,27 @@ Utils::PrintMapping(const milvus::Mapping& mapping) {
BLOCK_SPLITER
std::cout << "Collection name: " << mapping.collection_name << std::endl;
for (const auto& field : mapping.fields) {
std::cout << "field name: " << field->field_name << "\t field type: " << (int32_t)field->field_type
<< "\t field index params:" << field->index_params << "\t field extra params: " << field->extra_params
std::cout << "field name: " << field->name << "\t field type: " << (int32_t)field->type
<< "\t field index params:" << field->index_params << "\t field extra params: " << field->params
<< std::endl;
}
std::cout << "Collection extra params: " << mapping.extra_params << std::endl;
BLOCK_SPLITER
}
void
Utils::BuildVectors(int64_t dim, int64_t nb, std::vector<milvus::VectorData>& vectors) {
std::default_random_engine e;
std::uniform_real_distribution<float> u(0, 1);
vectors.resize(nb);
for (int64_t i = 0; i < nb; i++) {
vectors[i].float_data.resize(dim);
for (int64_t j = 0; j < dim; j++) {
vectors[i].float_data[j] = u(e);
}
}
}
void
Utils::BuildEntities(int64_t from, int64_t to, milvus::FieldValue& field_value, std::vector<int64_t>& entity_ids,
int64_t dimension) {
@ -336,39 +349,23 @@ Utils::GenLeafQuery() {
}
void
Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type) {
uint64_t row_num = 10000;
std::vector<int64_t> term_value;
term_value.resize(row_num);
for (uint64_t i = 0; i < row_num; ++i) {
term_value[i] = i;
}
Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, int64_t topk,
const std::string& metric_type) {
std::vector<int64_t> term_vale = {2002, 2003};
std::vector<std::vector<float>> embedding;
// dsl_json = {{"bool",
// {"must",
// {{"term", {"release_year", {2002, 2003}}},
// {{"range", {{"duration", {"GT", 250}}}}},
// {{"vector", "placeholder_1"}}}}}};
JSON must_query, term_query, range_query, vector_query;
term_query["term"]["release_year"] = {2002, 2003};
range_query["range"]["duration"] = {{"GT", 250}};
vector_query["vector"] = "placeholder_1";
must_query["must"] = {term_query, range_query, vector_query};
dsl_json["bool"] = must_query;
nlohmann::json bool_json, term_json, range_json, vector_json;
nlohmann::json term_value_json;
term_value_json["values"] = term_value;
term_json["term"]["field_1"] = term_value_json;
bool_json["must"].push_back(term_json);
nlohmann::json comp_json;
comp_json["GT"] = 0;
comp_json["LT"] = 100000;
range_json["range"]["field_1"] = comp_json;
bool_json["must"].push_back(range_json);
std::string placeholder = "placeholder_1";
vector_json["vector"] = placeholder;
bool_json["must"].push_back(vector_json);
dsl_json["bool"] = bool_json;
nlohmann::json query_vector_json, vector_extra_params;
int64_t topk = 10;
query_vector_json["topk"] = topk;
query_vector_json["metric_type"] = metric_type;
vector_extra_params["nprobe"] = 64;
query_vector_json["params"] = vector_extra_params;
vector_param_json[placeholder]["field_vec"] = query_vector_json;
vector_param_json = {{"placeholder_1", {{"embedding", {{"topk", topk}, {"metric_type", metric_type}}}}}};
}
void
@ -392,39 +389,27 @@ Utils::GenPureVecDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_
void
Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
for (size_t i = 0; i < topk_query_result.size(); i++) {
auto field_value = topk_query_result[i].field_value;
for (auto& int32_it : field_value.int32_value) {
std::cout << int32_it.first << ":";
for (auto& data : int32_it.second) {
std::cout << " " << data;
}
std::cout << std::endl;
}
for (auto& int64_it : field_value.int64_value) {
std::cout << int64_it.first << ":";
for (auto& data : int64_it.second) {
std::cout << " " << data;
}
std::cout << std::endl;
}
for (auto& float_it : field_value.float_value) {
std::cout << float_it.first << ":";
for (auto& data : float_it.second) {
std::cout << " " << data;
}
std::cout << std::endl;
}
for (auto& double_it : field_value.double_value) {
std::cout << double_it.first << ":";
for (auto& data : double_it.second) {
std::cout << " " << data;
}
std::cout << std::endl;
}
auto entities = topk_query_result[i].entities;
for (size_t j = 0; j < topk_query_result[i].ids.size(); j++) {
std::cout << topk_query_result[i].ids[j] << " --------- " << topk_query_result[i].distances[j]
<< std::endl;
std::cout << "- id: " << topk_query_result[i].ids[j] << std::endl;
std::cout << "- distance: " << topk_query_result[i].distances[j] << std::endl;
for (const auto& data : entities[j].scalar_data) {
if (data.first == "duration" || data.first == "release_year") {
std::cout << "- " << data.first << ": " << std::any_cast<int32_t>(data.second) << std::endl;
}
}
for (const auto& data : entities[j].vector_data) {
if (data.first == "embedding") {
std::cout << "- " << data.first << ": ";
for (const auto& v : data.second.float_data) {
std::cout << v << " ";
}
std::cout << std::endl;
}
}
}
std::cout << std::endl;
}
}

View File

@ -60,6 +60,9 @@ class Utils {
BuildEntities(int64_t from, int64_t to, milvus::FieldValue& field_value, std::vector<int64_t>& entity_ids,
int64_t dimension);
static void
BuildVectors(int64_t dim, int64_t nb, std::vector<milvus::VectorData>& vectors);
static void
PrintSearchResult(const std::vector<std::pair<int64_t, milvus::VectorData>>& entity_array,
const milvus::TopKQueryResult& topk_query_result);
@ -82,7 +85,8 @@ class Utils {
GenLeafQuery();
static void
GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type);
GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, int64_t topk,
const std::string& metric_type);
static void
GenPureVecDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type);

View File

@ -66,15 +66,21 @@ CopyRowRecord(::milvus::grpc::VectorRowRecord* target, const VectorData& src) {
}
void
ConstructTopkResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryResult& topk_query_result) {
topk_query_result.reserve(grpc_result.row_num());
ConstructTopkQueryResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryResult& topk_query_result) {
int64_t nq = grpc_result.row_num();
int64_t topk = grpc_result.entities().ids_size() / nq;
for (int64_t i = 0; i < nq; i++) {
milvus::QueryResult one_result;
if (nq == 0) {
return;
}
topk_query_result.reserve(nq);
const auto& grpc_entity = grpc_result.entities();
int64_t topk = grpc_entity.ids_size() / nq;
int64_t offset = 0;
for (int64_t i = 0; i < grpc_result.row_num(); i++) {
milvus::QueryResult one_result = milvus::QueryResult();
one_result.ids.resize(topk);
one_result.distances.resize(topk);
memcpy(one_result.ids.data(), grpc_result.entities().ids().data() + topk * i, topk * sizeof(int64_t));
memcpy(one_result.ids.data(), grpc_entity.ids().data() + topk * i, topk * sizeof(int64_t));
memcpy(one_result.distances.data(), grpc_result.distances().data() + topk * i, topk * sizeof(float));
int valid_size = one_result.ids.size();
@ -86,74 +92,52 @@ ConstructTopkResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryRes
one_result.distances.resize(valid_size);
}
topk_query_result.emplace_back(one_result);
}
}
void
ConstructTopkQueryResult(const ::milvus::grpc::QueryResult& grpc_result, TopKQueryResult& topk_query_result) {
int64_t nq = grpc_result.row_num();
if (nq == 0) {
return;
}
topk_query_result.reserve(nq);
const auto& grpc_entity = grpc_result.entities();
int64_t topk = grpc_entity.ids_size() / nq;
// TODO(yukun): filter -1 results
for (int64_t i = 0; i < grpc_result.row_num(); i++) {
milvus::QueryResult one_result;
one_result.ids.resize(topk);
one_result.distances.resize(topk);
memcpy(one_result.ids.data(), grpc_entity.ids().data() + topk * i, topk * sizeof(int64_t));
memcpy(one_result.distances.data(), grpc_result.distances().data() + topk * i, topk * sizeof(float));
int64_t j;
for (j = 0; j < grpc_entity.fields_size(); j++) {
auto grpc_field = grpc_entity.fields(j);
if (grpc_field.has_attr_record()) {
if (grpc_field.attr_record().int32_value_size() > 0) {
std::vector<int32_t> int32_data(topk);
memcpy(int32_data.data(), grpc_field.attr_record().int32_value().data() + topk * i,
topk * sizeof(int32_t));
one_result.field_value.int32_value.insert(std::make_pair(grpc_field.field_name(), int32_data));
} else if (grpc_field.attr_record().int64_value_size() > 0) {
std::vector<int64_t> int64_data(topk);
memcpy(int64_data.data(), grpc_field.attr_record().int64_value().data() + topk * i,
topk * sizeof(int64_t));
one_result.field_value.int64_value.insert(std::make_pair(grpc_field.field_name(), int64_data));
} else if (grpc_field.attr_record().float_value_size() > 0) {
std::vector<float> float_data(topk);
memcpy(float_data.data(), grpc_field.attr_record().float_value().data() + topk * i,
topk * sizeof(float));
one_result.field_value.float_value.insert(std::make_pair(grpc_field.field_name(), float_data));
} else if (grpc_field.attr_record().double_value_size() > 0) {
std::vector<double> double_data(topk);
memcpy(double_data.data(), grpc_field.attr_record().double_value().data() + topk * i,
topk * sizeof(double));
one_result.field_value.double_value.insert(std::make_pair(grpc_field.field_name(), double_data));
}
}
if (grpc_field.has_vector_record()) {
int64_t vector_row_count = grpc_field.vector_record().records_size();
if (vector_row_count > 0) {
std::vector<VectorData> vector_data(topk);
for (int64_t k = topk * i; k < topk * (i + 1); k++) {
auto grpc_vector_data = grpc_field.vector_record().records(k);
if (grpc_vector_data.float_data_size() > 0) {
vector_data[k].float_data.resize(grpc_vector_data.float_data_size());
memcpy(vector_data[k].float_data.data(), grpc_vector_data.float_data().data(),
grpc_vector_data.float_data_size() * sizeof(float));
} else if (!grpc_vector_data.binary_data().empty()) {
vector_data[k].binary_data.resize(grpc_vector_data.binary_data().size() / 8);
memcpy(vector_data[k].binary_data.data(), grpc_vector_data.binary_data().data(),
grpc_vector_data.binary_data().size());
for (int64_t k = 0; k < topk; k++) {
std::unordered_map<std::string, std::any> scalar_data;
std::unordered_map<std::string, milvus::VectorData> vector_data;
if (grpc_entity.valid_row(i * topk + k)) {
for (int64_t j = 0; j < grpc_entity.fields_size(); j++) {
const auto& grpc_field = grpc_entity.fields(j);
if (grpc_field.has_attr_record()) {
if (grpc_field.attr_record().int32_value_size() > 0) {
scalar_data.insert({grpc_field.field_name(), grpc_field.attr_record().int32_value(offset)});
} else if (grpc_field.attr_record().int64_value_size() > 0) {
scalar_data.insert({grpc_field.field_name(), grpc_field.attr_record().int64_value(offset)});
} else if (grpc_field.attr_record().float_value_size() > 0) {
scalar_data.insert({grpc_field.field_name(), grpc_field.attr_record().float_value(offset)});
} else {
scalar_data.insert(
{grpc_field.field_name(), grpc_field.attr_record().double_value(offset)});
}
} else {
auto float_size = grpc_field.vector_record().records(offset).float_data_size();
auto bin_size = grpc_field.vector_record().records(offset).binary_data().size();
milvus::VectorData vectors;
if (float_size > 0) {
std::vector<float> float_data(float_size);
memcpy(float_data.data(), grpc_field.vector_record().records(offset).float_data().data(),
sizeof(float) * float_size);
vectors.float_data = float_data;
} else if (bin_size > 0) {
std::vector<uint8_t> bin_data(bin_size / 8);
memcpy(bin_data.data(), grpc_field.vector_record().records(offset).binary_data().data(),
bin_size);
vectors.binary_data = bin_data;
}
vector_data.insert({grpc_field.field_name(), vectors});
}
one_result.field_value.vector_value.insert(std::make_pair(grpc_field.field_name(), vector_data));
}
if (!scalar_data.empty() || !vector_data.empty()) {
Entity entity;
entity.entity_id = grpc_entity.ids(i * topk + k);
entity.scalar_data = scalar_data;
entity.vector_data = vector_data;
one_result.entities.emplace_back(entity);
offset++;
}
}
}
topk_query_result.emplace_back(one_result);
}
}
@ -225,6 +209,57 @@ CopyFieldValue(const FieldValue& field_value, ::milvus::grpc::InsertParam& inser
}
}
void
CopyEntities(::milvus::grpc::Entities& grpc_entities, Entities& entities) {
auto grpc_field_size = grpc_entities.fields_size();
std::vector<std::string> field_names(grpc_field_size);
for (int64_t i = 0; i < grpc_field_size; i++) {
field_names[i] = grpc_entities.fields(i).field_name();
}
int row_num = grpc_entities.ids_size();
int64_t offset = 0;
for (int64_t i = 0; i < row_num; i++) {
if (!grpc_entities.valid_row(i)) {
continue;
}
milvus::Entity entity = milvus::Entity();
entity.entity_id = grpc_entities.ids(i);
for (int64_t j = 0; j < grpc_field_size; j++) {
const auto& grpc_field = grpc_entities.fields(j);
auto field_name = grpc_field.field_name();
if (grpc_field.has_attr_record()) {
const auto& grpc_attr_record = grpc_field.attr_record();
if (grpc_attr_record.int32_value_size() > 0) {
entity.scalar_data.insert({field_name, grpc_attr_record.int32_value(offset)});
} else if (grpc_attr_record.int64_value_size() > 0) {
entity.scalar_data.insert({field_name, grpc_attr_record.int64_value(offset)});
} else if (grpc_attr_record.float_value_size() > 0) {
entity.scalar_data.insert({field_name, grpc_attr_record.float_value(offset)});
} else if (grpc_attr_record.double_value_size() > 0) {
entity.scalar_data.insert({field_name, grpc_attr_record.double_value(offset)});
}
} else if (grpc_field.has_vector_record()) {
const auto& grpc_vector_record = grpc_field.vector_record();
const auto& record = grpc_vector_record.records(offset);
milvus::VectorData vector_data;
if (record.float_data_size() > 0) {
std::vector<float> data(record.float_data_size());
memcpy(data.data(), record.float_data().data(), record.float_data_size() * sizeof(float));
vector_data.float_data = data;
} else if (record.binary_data().size() > 0) {
std::vector<uint8_t> data(record.binary_data().size());
memcpy(data.data(), record.binary_data().data(), record.binary_data().size());
vector_data.binary_data = data;
}
entity.vector_data.insert({field_name, vector_data});
}
}
entities.emplace_back(entity);
offset++;
}
}
void
CopyEntityToJson(::milvus::grpc::Entities& grpc_entities, JSON& json_entity) {
int i;
@ -401,29 +436,23 @@ ClientProxy::Disconnect() {
}
Status
ClientProxy::CreateCollection(const Mapping& mapping, const std::string& extra_params) {
ClientProxy::CreateCollection(const Mapping& mapping) {
CLIENT_NULL_CHECK(client_ptr_);
try {
::milvus::grpc::Mapping grpc_mapping;
grpc_mapping.set_collection_name(mapping.collection_name);
for (auto& field : mapping.fields) {
auto grpc_field = grpc_mapping.add_fields();
grpc_field->set_name(field->field_name);
grpc_field->set_type((::milvus::grpc::DataType)field->field_type);
JSON json_index_param = JSON::parse(field->index_params);
for (auto& json_param : json_index_param.items()) {
auto grpc_index_param = grpc_field->add_index_params();
grpc_index_param->set_key(json_param.key());
grpc_index_param->set_value(json_param.value());
}
grpc_field->set_name(field->name);
grpc_field->set_type((::milvus::grpc::DataType)field->type);
auto grpc_extra_param = grpc_field->add_extra_params();
grpc_extra_param->set_key(EXTRA_PARAM_KEY);
grpc_extra_param->set_value(field->extra_params);
grpc_extra_param->set_value(field->params);
}
auto grpc_param = grpc_mapping.add_extra_params();
grpc_param->set_key(EXTRA_PARAM_KEY);
grpc_param->set_value(extra_params);
grpc_param->set_value(mapping.extra_params);
return client_ptr_->CreateCollection(grpc_mapping);
} catch (std::exception& ex) {
@ -485,7 +514,7 @@ ClientProxy::GetCollectionInfo(const std::string& collection_name, Mapping& mapp
for (int64_t i = 0; i < grpc_mapping.fields_size(); i++) {
const auto& grpc_field = grpc_mapping.fields(i);
FieldPtr field_ptr = std::make_shared<Field>();
field_ptr->field_name = grpc_field.name();
field_ptr->name = grpc_field.name();
JSON json_index_params;
for (int64_t j = 0; j < grpc_field.index_params_size(); j++) {
JSON json_param;
@ -499,8 +528,8 @@ ClientProxy::GetCollectionInfo(const std::string& collection_name, Mapping& mapp
json_param = JSON::parse(grpc_field.extra_params(j).value());
json_extra_params.emplace_back(json_param);
}
field_ptr->extra_params = json_extra_params.dump();
field_ptr->field_type = (DataType)grpc_field.type();
field_ptr->params = json_extra_params.dump();
field_ptr->type = (DataType)grpc_field.type();
mapping.fields.emplace_back(field_ptr);
}
if (!grpc_mapping.extra_params().empty()) {
@ -676,7 +705,7 @@ ClientProxy::Insert(const std::string& collection_name, const std::string& parti
Status
ClientProxy::GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array,
std::string& entities) {
Entities& entities) {
CLIENT_NULL_CHECK(client_ptr_);
try {
::milvus::grpc::EntityIdentity entity_identity;
@ -691,9 +720,7 @@ ClientProxy::GetEntityByID(const std::string& collection_name, const std::vector
return status;
}
JSON json_entities;
CopyEntityToJson(grpc_entities, json_entities);
entities = json_entities.dump();
CopyEntities(grpc_entities, entities);
return status;
} catch (std::exception& ex) {
return Status(StatusCode::UnknownError, "Failed to get entity by id: " + std::string(ex.what()));

View File

@ -36,7 +36,7 @@ class ClientProxy : public Connection {
Disconnect() override;
Status
CreateCollection(const Mapping& mapping, const std::string& extra_params) override;
CreateCollection(const Mapping& mapping) override;
Status
DropCollection(const std::string& collection_name) override;
@ -81,7 +81,7 @@ class ClientProxy : public Connection {
Status
GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array,
std::string& entities) override;
Entities& entities) override;
Status
DeleteEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array) override;

View File

@ -16,7 +16,6 @@
#include <vector>
#include "Status.h"
namespace milvus {
enum class DataType {
@ -40,11 +39,16 @@ enum class DataType {
// Base struct of all fields
struct Field {
uint64_t field_id; ///< read-only
std::string field_name;
DataType field_type;
std::string index_params;
std::string extra_params;
std::string name;
DataType type;
std::string index_params = "{}";
std::string params;
Field(const std::string& _name, const DataType _type, const std::string& _params)
: name(_name), type(_type), params(_params) {
}
Field() = default;
};
using FieldPtr = std::shared_ptr<Field>;
@ -63,4 +67,4 @@ struct VectorField : Field {
};
using VectorFieldPtr = std::shared_ptr<VectorField>;
} // namespace milvus
} // namespace milvus

View File

@ -76,6 +76,13 @@ struct AttrRecord {
/**
* @brief field value
*/
struct Entity {
int64_t entity_id;
std::unordered_map<std::string, std::any> scalar_data;
std::unordered_map<std::string, VectorData> vector_data;
};
using Entities = std::vector<Entity>;
struct FieldValue {
int64_t row_num;
std::unordered_map<std::string, std::vector<int32_t>> int32_value;
@ -99,6 +106,7 @@ struct VectorParam {
struct QueryResult {
std::vector<int64_t> ids; ///< Query entity ids result
std::vector<float> distances; ///< Query distances result
Entities entities;
FieldValue field_value;
};
using TopKQueryResult = std::vector<QueryResult>; ///< Topk hybrid query result
@ -229,7 +237,7 @@ class Connection {
* @return Indicate if collection is created successfully
*/
virtual Status
CreateCollection(const Mapping& mapping, const std::string& extra_params) = 0;
CreateCollection(const Mapping& mapping) = 0;
/**
* @brief Drop collection method
@ -416,7 +424,7 @@ class Connection {
* @return Indicate if the operation is succeed.
*/
virtual Status
GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array, std::string& entities) = 0;
GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array, Entities& entities) = 0;
/**
* @brief Delete entity by id

View File

@ -52,8 +52,8 @@ ConnectionImpl::Disconnect() {
}
Status
ConnectionImpl::CreateCollection(const Mapping& mapping, const std::string& extra_params) {
return client_proxy_->CreateCollection(mapping, extra_params);
ConnectionImpl::CreateCollection(const Mapping& mapping) {
return client_proxy_->CreateCollection(mapping);
}
Status
@ -124,7 +124,7 @@ ConnectionImpl::Insert(const std::string& collection_name, const std::string& pa
Status
ConnectionImpl::GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array,
std::string& entities) {
Entities& entities) {
return client_proxy_->GetEntityByID(collection_name, id_array, entities);
}

View File

@ -53,7 +53,7 @@ class ConnectionImpl : public Connection {
// SetConfig(const std::string& node_name, const std::string& value) const override;
Status
CreateCollection(const Mapping& mapping, const std::string& extra_params) override;
CreateCollection(const Mapping& mapping) override;
Status
DropCollection(const std::string& collection_name) override;
@ -98,7 +98,7 @@ class ConnectionImpl : public Connection {
Status
GetEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array,
std::string& entities) override;
Entities& entities) override;
Status
DeleteEntityByID(const std::string& collection_name, const std::vector<int64_t>& id_array) override;