// Pulsar-backed message client used by the proxy: publishes insert/delete and
// search requests and collects the per-query-node search results.
//
// NOTE(review): the template arguments (e.g. MsgClient, MsgProducer,
// MsgConsumer, the segment_id std::function signature) and the system
// #include names below were restored from how they are used in this file;
// confirm them against ClientV2.h and its companion headers before merging.
#include "ClientV2.h"
#include "pulsar/Result.h"
#include "PartitionPolicy.h"
#include "utils/CommonUtil.h"
#include "M3_hash.h"
#include "config/ServerConfig.h"

#include <omp.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "nlohmann/json.hpp"
#include "log/Log.h"

namespace milvus::message_client {

// Search results received so far, keyed by query id. GetQueryResult() fills it
// (possibly with results that belong to other queries) and removes an entry
// once that query has been aggregated or has failed.
// NOTE(review): access is not synchronized; confirm that callers serialize
// GetQueryResult() or add locking.
std::map<int64_t, std::vector<std::shared_ptr<grpc::QueryResult>>> total_results;

// Stores the connection parameters only; no Pulsar object is created until
// Init() is called. The `config` argument is currently unused.
MsgClientV2::MsgClientV2(int64_t client_id,
                         const std::string &service_url,
                         const uint32_t mut_parallelism,
                         const pulsar::ClientConfiguration &config)
    : client_id_(client_id), service_url_(service_url), mut_parallelism_(mut_parallelism) {
}

// Creates the Pulsar client, all producers and the search-result consumer.
//
// insert_delete  topic prefix; one producer is created per "<prefix>-<i>"
//                for i in [0, mut_parallelism_).
// search, search_by_id, time_sync  topics of the corresponding producers.
// search_result  prefix of the consumer name and of the subscribed topic
//                "<search_result>-<proxy id>".
//
// Returns SERVER_UNEXPECTED_ERROR when the subscription fails.
Status MsgClientV2::Init(const std::string &insert_delete,
                         const std::string &search,
                         const std::string &time_sync,
                         const std::string &search_by_id,
                         const std::string &search_result) {
    // create pulsar client (token authentication is optional)
    std::shared_ptr<MsgClient> pulsar_client;
    if (config.pulsar.authentication) {
        pulsar::ClientConfiguration clientConfig;
        clientConfig.setAuth(pulsar::AuthToken::createWithToken(config.pulsar.token.value));
        pulsar_client = std::make_shared<MsgClient>(service_url_, clientConfig);
    } else {
        pulsar_client = std::make_shared<MsgClient>(service_url_);
    }

    // create pulsar producers; all but the time-sync producer use the custom router
    ProducerConfiguration producerConfiguration;
    producerConfiguration.setPartitionsRoutingMode(ProducerConfiguration::CustomPartition);
    producerConfiguration.setMessageRouter(std::make_shared<PartitionPolicy>());

    search_producer_ = std::make_shared<MsgProducer>(pulsar_client, search, producerConfiguration);
    search_by_id_producer_ = std::make_shared<MsgProducer>(pulsar_client, search_by_id, producerConfiguration);
    time_sync_producer_ = std::make_shared<MsgProducer>(pulsar_client, time_sync);

    for (uint32_t i = 0; i < mut_parallelism_; i++) {
        std::string insert_or_delete_topic = insert_delete + "-" + std::to_string(i);
        paralle_mut_producers_.emplace_back(
            std::make_shared<MsgProducer>(pulsar_client, insert_or_delete_topic, producerConfiguration));
    }

    // create pulsar consumer; a random suffix keeps the consumer name unique
    std::string subscribe_name = std::to_string(CommonUtil::RandomUINT64());
    consumer_ = std::make_shared<MsgConsumer>(pulsar_client, search_result + subscribe_name);

    std::string search_topic = search_result + "-" + std::to_string(config.proxy_id());
    auto result = consumer_->subscribe(search_topic);
    if (result != pulsar::Result::ResultOk) {
        return Status(SERVER_UNEXPECTED_ERROR,
                      "Pulsar message client init occur error, " + std::string(pulsar::strResult(result)));
    }
    return Status::OK();
}

// Number of query nodes whose results must arrive before a search is complete.
int64_t GetQueryNodeNum() {
    return config.master.query_node_num();
}

// Merges the per-node results of one search into `result`.
//
// For every query the hits of all nodes are ordered by ascending distance
// (ties keep node order, then in-node order) and the best topK are emitted,
// where topK = distances_size / row_num of the first node's result.
//
// Returns DB_ERROR when `results` is empty, when the first result reports no
// query rows, when a node reports a non-success status, or when a node carries
// fewer scores/ids than distances. If the nodes together deliver fewer than
// topK hits for a query, only the available hits are emitted for that query.
Status Aggregation(std::vector<std::shared_ptr<grpc::QueryResult>> results, milvus::grpc::QueryResult *result) {
    if (results.empty()) {
        return Status(DB_ERROR, "The result is null!");
    }

    // Proxy get numQueries from row_num.
    const int64_t numQueries = results[0]->row_num();
    if (numQueries <= 0) {
        // Also protects the topK division below.
        return Status(DB_ERROR, "The result contains no query rows!");
    }
    const int64_t topK = results[0]->distances_size() / numQueries;

    // A hit is referenced by (node, offset) so that score, distance and id are
    // always read from the same slot of the same node result.
    struct HitRef {
        size_t node;
        int offset;
    };
    std::vector<std::vector<HitRef>> hits(numQueries);

    for (size_t node = 0; node < results.size(); ++node) {
        const auto &result_per_node = results[node];
        if (result_per_node->status().error_code() != grpc::ErrorCode::SUCCESS) {
            return Status(DB_ERROR, "QueryNode return wrong status!");
        }

        const int64_t hit_count = result_per_node->distances_size();
        const int64_t needed = std::min<int64_t>(numQueries * topK, hit_count);
        if (result_per_node->scores_size() < needed || result_per_node->entities().ids_size() < needed) {
            return Status(DB_ERROR, "QueryNode return inconsistent result sizes!");
        }

        for (int64_t i = 0; i < numQueries; i++) {
            const int64_t end = std::min<int64_t>((i + 1) * topK, hit_count);
            for (int64_t j = i * topK; j < end; j++) {
                hits[i].push_back({node, static_cast<int>(j)});
            }
        }
    }

    for (auto &query_hits : hits) {
        std::stable_sort(query_hits.begin(), query_hits.end(),
                         [&results](const HitRef &lhs, const HitRef &rhs) {
                             return results[lhs.node]->distances(lhs.offset) <
                                    results[rhs.node]->distances(rhs.offset);
                         });
    }

    grpc::Entities result_entities;
    for (const auto &query_hits : hits) {
        const size_t keep = std::min<size_t>(query_hits.size(), static_cast<size_t>(topK));
        for (size_t m = 0; m < keep; ++m) {
            const HitRef &hit = query_hits[m];
            const auto &source = *results[hit.node];
            result->add_scores(source.scores(hit.offset));
            result->add_distances(source.distances(hit.offset));
            result_entities.add_ids(source.entities().ids(hit.offset));
        }
    }

    // Every node passed the status check above; the last node's status is reported.
    result_entities.mutable_status()->CopyFrom(results.back()->status());
    result->set_row_num(numQueries);
    result->mutable_entities()->CopyFrom(result_entities);
    result->set_query_id(results[0]->query_id());
    return Status::OK();
}

// Blocks until every query node has answered `query_id`, then aggregates the
// answers into `result`.
//
// Messages read from the consumer are filed under their own query id, so a
// message that belongs to another query is kept for that query's caller.
// Returns DB_ERROR when receiving or parsing fails or when a received result
// reports a failure.
// NOTE(review): a failed result of a different query also aborts this call,
// as it did before; confirm whether that is intended.
Status MsgClientV2::GetQueryResult(int64_t query_id, milvus::grpc::QueryResult *result) {
    const int64_t query_node_num = GetQueryNodeNum();

    // ">=" so that an unexpected extra result cannot keep the loop alive forever.
    while (static_cast<int64_t>(total_results[query_id].size()) < query_node_num) {
        Message msg;
        // NOTE(review): assumes MsgConsumer::receive mirrors
        // pulsar::Consumer::receive and returns a pulsar::Result - confirm.
        const auto receive_result = consumer_->receive(msg);
        if (receive_result != pulsar::ResultOk) {
            // An empty message would parse "successfully"; fail instead of spinning.
            total_results.erase(query_id);
            return Status(DB_ERROR,
                          "Pulsar receive search result failed, " + std::string(pulsar::strResult(receive_result)));
        }

        auto message = std::make_shared<grpc::QueryResult>();
        const bool parsed = message->ParseFromString(msg.getDataAsString());
        consumer_->acknowledge(msg);

        if (!parsed) {
            total_results.erase(query_id);
            return Status(DB_ERROR, "can't parse message which from pulsar!");
        }
        if (message->status().error_code() != grpc::ErrorCode::SUCCESS) {
            total_results.erase(query_id);
            return Status(DB_ERROR, "Search Failed");
        }
        total_results[message->query_id()].push_back(message);
    }

    // Take the collected results out of the map so finished queries do not accumulate.
    auto collected = std::move(total_results[query_id]);
    total_results.erase(query_id);
    return Aggregation(std::move(collected), result);
}

// Publishes one INSERT message per row of `request`.
//
// Rows are distributed over config.pulsar.topicnum() OpenMP threads; each
// thread sends through the producer whose index equals its thread number and
// uses that number as channel id. `segment_id` maps (collection name, channel
// id, timestamp) to the segment id stored in the message. The call returns
// after every asynchronous send has reported back.
//
// Returns the first non-OK per-thread status, or DB_ERROR when the request or
// the topic configuration is inconsistent. Also appends a throughput record
// to /tmp/proxy2pulsar.benchmark once at least 1,000,000 rows were inserted
// since the last record.
// NOTE(review): the benchmark counters are function-local statics without
// synchronization; confirm this overload is not called concurrently.
Status MsgClientV2::SendMutMessage(const milvus::grpc::InsertParam &request,
                                   uint64_t timestamp,
                                   const std::function<uint64_t(const std::string &collection_name,
                                                                uint64_t channel_id,
                                                                uint64_t timestamp)> &segment_id) {
    const uint64_t num_records_log = 100 * 10000;
    static uint64_t num_inserted = 0;
    static uint64_t size_inserted = 0;
    using stdclock = std::chrono::high_resolution_clock;
    static stdclock::duration time_cost;
    auto start = stdclock::now();

    // may have retry policy?
    // row_count and topic_num stay non-const: they are listed in the OpenMP
    // shared() clause below.
    auto row_count = request.rows_data_size();
    auto topic_num = config.pulsar.topicnum();

    if (request.entity_id_array_size() < row_count) {
        return Status(DB_ERROR, "Insert request has fewer entity ids than rows");
    }
    // Each thread indexes paralle_mut_producers_ with its thread number.
    if (topic_num <= 0 || static_cast<size_t>(topic_num) > paralle_mut_producers_.size()) {
        return Status(DB_ERROR, "pulsar topic number does not match the number of mutation producers");
    }

    auto stats = std::vector<Status>(topic_num);
    std::atomic_uint64_t msg_sended = 0;

#pragma omp parallel for default(none), shared(row_count, request, timestamp, stats, segment_id, msg_sended, topic_num), num_threads(topic_num)
    for (auto i = 0; i < row_count; i++) {
        milvus::grpc::InsertOrDeleteMsg mut_msg;
        int this_thread = omp_get_thread_num();
        mut_msg.set_op(milvus::grpc::OpType::INSERT);
        mut_msg.set_uid(request.entity_id_array(i));
        mut_msg.set_client_id(client_id_);
        mut_msg.set_timestamp(timestamp);
        mut_msg.set_collection_name(request.collection_name());
        mut_msg.set_partition_tag(request.partition_tag());
        // TODO: routing by makeHash(&uid, sizeof(uint64_t)) % topic_num is
        // disabled (its correctness was not verified); the thread number is used.
        auto channel_id = this_thread;
        try {
            mut_msg.set_segment_id(segment_id(request.collection_name(), channel_id, timestamp));
            mut_msg.mutable_rows_data()->CopyFrom(request.rows_data(i));
            mut_msg.mutable_extra_params()->CopyFrom(request.extra_params());
            mut_msg.set_channel_id(channel_id);

            auto callback = [&stats, &msg_sended, this_thread](Result result, const pulsar::MessageId &messageId) {
                msg_sended += 1;
                if (result != pulsar::ResultOk) {
                    stats[this_thread] = Status(DB_ERROR, pulsar::strResult(result));
                }
            };
            paralle_mut_producers_[channel_id]->sendAsync(mut_msg, callback);
        } catch (const std::exception &e) {
            // Count the row as handled so the wait below still terminates.
            msg_sended += 1;
            stats[this_thread] = Status(DB_ERROR, e.what());
        }
    }

    // The callbacks reference locals of this frame, so wait for all of them.
    while (msg_sended < static_cast<uint64_t>(row_count)) {
        std::this_thread::yield();
    }

    auto end = stdclock::now();
    time_cost += (end - start);
    num_inserted += row_count;
    size_inserted += request.ByteSize();
    if (num_inserted >= num_records_log) {
        auto path = std::string("/tmp");
        std::ofstream file(path + "/proxy2pulsar.benchmark", std::fstream::app);
        const auto cost_ms = std::chrono::duration_cast<std::chrono::milliseconds>(time_cost).count();
        nlohmann::json json;
        json["InsertTime"] = milvus::CommonUtil::TimeToString(start);
        json["DurationInMilliseconds"] = cost_ms;
        json["SizeInMB"] = size_inserted / 1024.0 / 1024.0;
        json["ThroughputInMB"] = double(size_inserted) / cost_ms * 1000 / 1024.0 / 1024;
        json["NumRecords"] = num_inserted;
        file << json.dump() << std::endl;

        time_cost = stdclock::duration(0);
        num_inserted = 0;
        size_inserted = 0;
    }

    for (auto &stat : stats) {
        if (!stat.ok()) {
            return stat;
        }
    }
    return Status::OK();
}

// Publishes one DELETE message per id of `request`.
//
// Ids are distributed over config.pulsar.topicnum() OpenMP threads; each
// thread sends through the producer whose index equals its thread number.
// `segment_id` is not used by this overload. The call returns after every
// asynchronous send has reported back and logs the batch size and throughput.
//
// Returns the first non-OK per-thread status, or DB_ERROR when the topic
// configuration does not match the available producers.
Status MsgClientV2::SendMutMessage(const milvus::grpc::DeleteByIDParam &request,
                                   uint64_t timestamp,
                                   const std::function<uint64_t(const std::string &collection_name,
                                                                uint64_t channel_id,
                                                                uint64_t timestamp)> &segment_id) {
    using stdclock = std::chrono::high_resolution_clock;
    auto start = stdclock::now();

    // row_count and topicnum stay non-const: they are listed in the OpenMP
    // shared() clause below.
    auto row_count = request.id_array_size();
    auto topicnum = config.pulsar.topicnum();

    // Each thread indexes paralle_mut_producers_ with its thread number.
    if (topicnum <= 0 || static_cast<size_t>(topicnum) > paralle_mut_producers_.size()) {
        return Status(DB_ERROR, "pulsar topic number does not match the number of mutation producers");
    }

    auto stats = std::vector<Status>(topicnum);
    std::atomic_uint64_t msg_sended = 0;

#pragma omp parallel for default(none), shared( request, timestamp, stats, segment_id, msg_sended, row_count, topicnum), num_threads(topicnum)
    for (auto i = 0; i < row_count; i++) {
        milvus::grpc::InsertOrDeleteMsg mut_msg;
        mut_msg.set_op(milvus::grpc::OpType::DELETE);
        mut_msg.set_client_id(client_id_);
        mut_msg.set_uid(request.id_array(i));
        mut_msg.set_collection_name(request.collection_name());
        mut_msg.set_timestamp(timestamp);

        int this_thread = omp_get_thread_num();
        auto callback = [&stats, &msg_sended, this_thread](Result result, const pulsar::MessageId &messageId) {
            msg_sended += 1;
            if (result != pulsar::ResultOk) {
                stats[this_thread] = Status(DB_ERROR, pulsar::strResult(result));
            }
        };
        paralle_mut_producers_[this_thread]->sendAsync(mut_msg, callback);
    }

    // The callbacks reference locals of this frame, so wait for all of them.
    while (msg_sended < static_cast<uint64_t>(row_count)) {
        std::this_thread::yield();
    }

    auto end = stdclock::now();
    auto data_size = request.ByteSize();
    // Floating-point milliseconds: a batch that finishes in under 1 ms must
    // not cause an integer division by zero.
    const double elapsed_ms = std::chrono::duration<double, std::milli>(end - start).count();
    const double size_mb = data_size / 1024.0 / 1024.0;
    const double throughput_mb = elapsed_ms > 0.0 ? size_mb / elapsed_ms * 1000 : 0.0;
    LOG_SERVER_INFO_ << "DeleteReq Batch size:" << size_mb << "M, "
                     << "throughput: " << throughput_mb << "M/s";

    for (auto &stat : stats) {
        if (!stat.ok()) {
            return stat;
        }
    }
    return Status::OK();
}

// Builds a SearchMsg from `request` and publishes it on the search topic.
//
// A fresh query id is taken from GetUniqueQId() and returned through
// `query_id`. The float and binary data of all vector records are concatenated
// into a single row record, in request order. Returns DB_ERROR carrying the
// Pulsar error text when the send fails.
Status MsgClientV2::SendQueryMessage(const milvus::grpc::SearchParam &request, uint64_t timestamp, int64_t &query_id) {
    milvus::grpc::SearchMsg search_msg;

    query_id = GetUniqueQId();
    search_msg.set_collection_name(request.collection_name());
    search_msg.set_uid(query_id);
    // TODO: get client id from master
    search_msg.set_client_id(client_id_);
    search_msg.set_timestamp(timestamp);
    search_msg.set_dsl(request.dsl());

    milvus::grpc::VectorRowRecord vector_row_recode;
    std::string binary_data;
    for (const auto &param : request.vector_param()) {
        search_msg.add_json(param.json());
        for (const auto &record : param.row_record().records()) {
            for (const auto value : record.float_data()) {
                vector_row_recode.add_float_data(value);
            }
            binary_data.append(record.binary_data());
        }
    }
    vector_row_recode.set_binary_data(binary_data);
    search_msg.mutable_records()->CopyFrom(vector_row_recode);

    for (int m = 0; m < request.partition_tag_size(); ++m) {
        search_msg.add_partition_tag(request.partition_tag(m));
    }

    // search_msg starts with no extra_params, so indexing into it would be out
    // of range; copy the whole repeated field as the insert path does.
    search_msg.mutable_extra_params()->CopyFrom(request.extra_params());

    auto result = search_producer_->send(search_msg);
    if (result != pulsar::Result::ResultOk) {
        return Status(DB_ERROR, pulsar::strResult(result));
    }
    return Status::OK();
}

// Closes every producer and the consumer that Init() managed to create.
// Members are checked first because Init() may not have run or may have
// returned early.
MsgClientV2::~MsgClientV2() {
    for (auto &producer : paralle_mut_producers_) {
        if (producer) {
            producer->close();
        }
    }
    if (search_producer_) {
        search_producer_->close();
    }
    if (search_by_id_producer_) {
        search_by_id_producer_->close();
    }
    if (time_sync_producer_) {
        time_sync_producer_->close();
    }
    if (consumer_) {
        consumer_->close();
    }
}

}  // namespace milvus::message_client