mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-30 15:35:33 +08:00
Add new API for batch binary search
Former-commit-id: c035ac8fcfe672576c6207fca149c3d4030e7b53
This commit is contained in:
parent
13d7ebb5cf
commit
95f3900910
@ -126,13 +126,13 @@ VecServiceHandler::search_vector(VecSearchResult &_return,
|
||||
const VecTimeRangeList &time_range_list) {
|
||||
TimeRecordWrapper rc("search_vector()");
|
||||
SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k
|
||||
<< ", vector size = " << tensor.tensor.size()
|
||||
<< ", vector dimension = " << tensor.tensor.size()
|
||||
<< ", time range list size = " << time_range_list.range_list.size();
|
||||
|
||||
VecTensorList tensor_list;
|
||||
tensor_list.tensor_list.push_back(tensor);
|
||||
VecSearchResultList result;
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, result);
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, result);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
@ -154,7 +154,48 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
|
||||
<< ", vector list size = " << tensor_list.tensor_list.size()
|
||||
<< ", time range list size = " << time_range_list.range_list.size();
|
||||
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, _return);
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, _return);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::search_binary_vector(VecSearchResult& _return,
|
||||
const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensor& tensor,
|
||||
const VecTimeRangeList& time_range_list) {
|
||||
TimeRecordWrapper rc("search_binary_vector()");
|
||||
SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k
|
||||
<< ", vector dimension = " << tensor.tensor.size()
|
||||
<< ", time range list size = " << time_range_list.range_list.size();
|
||||
|
||||
VecBinaryTensorList tensor_list;
|
||||
tensor_list.tensor_list.push_back(tensor);
|
||||
VecSearchResultList result;
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, result);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
|
||||
if(!result.result_list.empty()) {
|
||||
_return = result.result_list[0];
|
||||
} else {
|
||||
SERVER_LOG_ERROR << "No search result returned";
|
||||
}
|
||||
}
|
||||
|
||||
void
|
||||
VecServiceHandler::search_binary_vector_batch(VecSearchResultList& _return,
|
||||
const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList& tensor_list,
|
||||
const VecTimeRangeList& time_range_list) {
|
||||
TimeRecordWrapper rc("search_binary_vector_batch()");
|
||||
SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k
|
||||
<< ", vector list size = " << tensor_list.tensor_list.size()
|
||||
<< ", time range list size = " << time_range_list.range_list.size();
|
||||
|
||||
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, _return);
|
||||
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
|
||||
scheduler.ExecuteTask(task_ptr);
|
||||
}
|
||||
|
||||
@ -67,6 +67,9 @@ public:
|
||||
|
||||
void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecTimeRangeList& time_range_list);
|
||||
|
||||
void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecTimeRangeList& time_range_list);
|
||||
|
||||
void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecTimeRangeList& time_range_list);
|
||||
};
|
||||
|
||||
|
||||
|
||||
@ -301,7 +301,7 @@ uint64_t AddBatchVectorTask::GetVecDimension(uint64_t index) const {
|
||||
if(index >= bin_tensor_list_->tensor_list.size()){
|
||||
return 0;
|
||||
}
|
||||
return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size();
|
||||
return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size()/8;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
@ -341,8 +341,6 @@ std::string AddBatchVectorTask::GetVecID(uint64_t index) const {
|
||||
|
||||
ServerError AddBatchVectorTask::OnExecute() {
|
||||
try {
|
||||
TimeRecorder rc("Add vector batch");
|
||||
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = group_id_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
@ -351,6 +349,7 @@ ServerError AddBatchVectorTask::OnExecute() {
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
TimeRecorder rc("Add vector batch");
|
||||
uint64_t group_dim = group_info.dimension;
|
||||
uint64_t vec_count = GetVecListCount();
|
||||
std::vector<float> vec_f;
|
||||
@ -402,13 +401,29 @@ ServerError AddBatchVectorTask::OnExecute() {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Construct a search task over float tensors. Exactly one of
// tensor_list_/bin_tensor_list_ is non-null; this overload selects the
// float flavor and leaves the binary pointer null.
SearchVectorTask::SearchVectorTask(const std::string& group_id,
                                   const int64_t top_k,
                                   const VecTensorList* tensor_list,
                                   const VecTimeRangeList& time_range_list,
                                   VecSearchResultList& result)
    : BaseTask(DQL_TASK_GROUP),
      group_id_(group_id),
      top_k_(top_k),
      tensor_list_(tensor_list),
      bin_tensor_list_(nullptr),  // binary flavor unused for this overload
      time_range_list_(time_range_list),
      result_(result) {
}
|
||||
|
||||
SearchVectorTask::SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList* bin_tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result)
|
||||
: BaseTask(DQL_TASK_GROUP),
|
||||
group_id_(group_id),
|
||||
top_k_(top_k),
|
||||
tensor_list_(nullptr),
|
||||
bin_tensor_list_(bin_tensor_list),
|
||||
time_range_list_(time_range_list),
|
||||
result_(result) {
|
||||
|
||||
@ -416,21 +431,101 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id,
|
||||
|
||||
// Factory for the float-tensor flavor of the search task.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecTensorList* tensor_list,
                                     const VecTimeRangeList& time_range_list,
                                     VecSearchResultList& result) {
    SearchVectorTask* task = new SearchVectorTask(group_id, top_k, tensor_list, time_range_list, result);
    return std::shared_ptr<BaseTask>(task);
}
|
||||
|
||||
// Factory for the binary-tensor flavor of the search task.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecBinaryTensorList* bin_tensor_list,
                                     const VecTimeRangeList& time_range_list,
                                     VecSearchResultList& result) {
    SearchVectorTask* task = new SearchVectorTask(group_id, top_k, bin_tensor_list, time_range_list, result);
    return std::shared_ptr<BaseTask>(task);
}
|
||||
|
||||
|
||||
// Flatten the query tensors into a contiguous float array of
// count * dim elements (row-major: vector i occupies [i*dim, (i+1)*dim)).
// Exactly one of tensor_list_/bin_tensor_list_ is expected to be set.
// Returns SERVER_INVALID_ARGUMENT if any vector's dimension differs from
// the first vector's; SERVER_SUCCESS otherwise (including when both lists
// are null/empty, in which case `data` is left untouched).
ServerError SearchVectorTask::GetTargetData(std::vector<float>& data) const {
    if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
        uint64_t count = tensor_list_->tensor_list.size();
        // Dimension of the first vector defines the expected dimension.
        uint64_t dim = tensor_list_->tensor_list[0].tensor.size();
        data.resize(count*dim);
        for(size_t i = 0; i < count; i++) {
            const auto& tensor = tensor_list_->tensor_list[i].tensor;
            if(tensor.size() != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << tensor.size();
                return SERVER_INVALID_ARGUMENT;
            }
            // Narrow each double component to float.
            const double* d_p = tensor.data();
            for(uint64_t k = 0; k < dim; k++) {  // fix: index type matches dim (was int64_t vs uint64_t)
                data[i*dim + k] = (float)(d_p[k]);
            }
        }
    } else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
        uint64_t count = bin_tensor_list_->tensor_list.size();
        // Binary tensors carry raw bytes; 8 bytes per component.
        uint64_t dim = bin_tensor_list_->tensor_list[0].tensor.size()/8;
        data.resize(count*dim);
        for(size_t i = 0; i < count; i++) {
            const auto& tensor = bin_tensor_list_->tensor_list[i].tensor;
            if(tensor.size()/8 != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << tensor.size()/8;
                return SERVER_INVALID_ARGUMENT;
            }
            // NOTE(review): reinterprets the byte payload as packed doubles —
            // presumes the client wrote native-endian doubles (as the test
            // client does); confirm alignment/endianness assumptions hold.
            const double* d_p = (const double*)(tensor.data());
            for(uint64_t k = 0; k < dim; k++) {
                data[i*dim + k] = (float)(d_p[k]);
            }
        }
    }

    return SERVER_SUCCESS;
}
|
||||
|
||||
uint64_t SearchVectorTask::GetTargetDimension() const {
|
||||
if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
|
||||
return tensor_list_->tensor_list[0].tensor.size();
|
||||
} else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
|
||||
return bin_tensor_list_->tensor_list[0].tensor.size()/8;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t SearchVectorTask::GetTargetCount() const {
|
||||
if(tensor_list_) {
|
||||
return tensor_list_->tensor_list.size();
|
||||
} else if(bin_tensor_list_) {
|
||||
return bin_tensor_list_->tensor_list.size();
|
||||
}
|
||||
}
|
||||
|
||||
ServerError SearchVectorTask::OnExecute() {
|
||||
try {
|
||||
std::vector<float> vec_f;
|
||||
for(const VecTensor& tensor : tensor_list_.tensor_list) {
|
||||
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
|
||||
engine::meta::GroupSchema group_info;
|
||||
group_info.group_id = group_id_;
|
||||
engine::Status stat = DB()->get_group(group_info);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
}
|
||||
|
||||
uint64_t vec_dim = GetTargetDimension();
|
||||
if(vec_dim != group_info.dimension) {
|
||||
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
|
||||
<< " vs. group dimension:" << group_info.dimension;
|
||||
return SERVER_INVALID_ARGUMENT;
|
||||
}
|
||||
|
||||
TimeRecorder rc("Search vector");
|
||||
std::vector<float> vec_f;
|
||||
ServerError err = GetTargetData(vec_f);
|
||||
if(err != SERVER_SUCCESS) {
|
||||
return err;
|
||||
}
|
||||
|
||||
uint64_t vec_count = GetTargetCount();
|
||||
|
||||
engine::QueryResults results;
|
||||
engine::Status stat = DB()->search(group_id_, (size_t)top_k_, tensor_list_.tensor_list.size(), vec_f.data(), results);
|
||||
stat = DB()->search(group_id_, (size_t)top_k_, vec_count, vec_f.data(), results);
|
||||
if(!stat.ok()) {
|
||||
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
|
||||
return SERVER_UNEXPECTED_ERROR;
|
||||
@ -438,7 +533,7 @@ ServerError SearchVectorTask::OnExecute() {
|
||||
for(engine::QueryResult& res : results){
|
||||
VecSearchResult v_res;
|
||||
std::string nid_prefix = group_id_ + "_";
|
||||
for(auto id : results[0]) {
|
||||
for(auto id : res) {
|
||||
std::string sid;
|
||||
std::string nid = nid_prefix + std::to_string(id);
|
||||
IVecIdMapper::GetInstance()->Get(nid, sid);
|
||||
|
||||
@ -129,24 +129,40 @@ class SearchVectorTask : public BaseTask {
|
||||
public:
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList& tensor_list,
|
||||
const VecTensorList* tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result);
|
||||
|
||||
static BaseTaskPtr Create(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList* bin_tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result);
|
||||
|
||||
protected:
|
||||
SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecTensorList& tensor_list,
|
||||
const VecTensorList* tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result);
|
||||
|
||||
ServerError OnExecute() override;
|
||||
SearchVectorTask(const std::string& group_id,
|
||||
const int64_t top_k,
|
||||
const VecBinaryTensorList* bin_tensor_list,
|
||||
const VecTimeRangeList& time_range_list,
|
||||
VecSearchResultList& result);
|
||||
|
||||
ServerError GetTargetData(std::vector<float>& data) const;
|
||||
uint64_t GetTargetDimension() const;
|
||||
uint64_t GetTargetCount() const;
|
||||
|
||||
ServerError OnExecute() override;
|
||||
|
||||
private:
|
||||
std::string group_id_;
|
||||
int64_t top_k_;
|
||||
const VecTensorList& tensor_list_;
|
||||
const VecTensorList* tensor_list_;
|
||||
const VecBinaryTensorList* bin_tensor_list_;
|
||||
const VecTimeRangeList& time_range_list_;
|
||||
VecSearchResultList& result_;
|
||||
};
|
||||
|
||||
@ -57,7 +57,7 @@ void ClientApp::Run(const std::string &config_file) {
|
||||
session.interface()->add_group(group);
|
||||
|
||||
//prepare data
|
||||
const int64_t count = 100000;
|
||||
const int64_t count = 10000;
|
||||
VecTensorList tensor_list;
|
||||
VecBinaryTensorList bin_tensor_list;
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
@ -79,36 +79,36 @@ void ClientApp::Run(const std::string &config_file) {
|
||||
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
// //add vectors one by one
|
||||
// {
|
||||
// server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
|
||||
// for (int64_t k = 0; k < count; k++) {
|
||||
// session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
|
||||
// if(k%1000 == 0) {
|
||||
// CLIENT_LOG_INFO << "add normal vector no." << k;
|
||||
// }
|
||||
// }
|
||||
// rc.Elapse("done!");
|
||||
// }
|
||||
//
|
||||
// //add vectors in one batch
|
||||
// {
|
||||
// server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
|
||||
// session.interface()->add_vector_batch(group.id, tensor_list);
|
||||
// rc.Elapse("done!");
|
||||
// }
|
||||
//
|
||||
// //add binary vectors one by one
|
||||
// {
|
||||
// server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
|
||||
// for (int64_t k = 0; k < count; k++) {
|
||||
// session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
|
||||
// if(k%1000 == 0) {
|
||||
// CLIENT_LOG_INFO << "add binary vector no." << k;
|
||||
// }
|
||||
// }
|
||||
// rc.Elapse("done!");
|
||||
// }
|
||||
//add vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
|
||||
if(k%1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add normal vector no." << k;
|
||||
}
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add vectors in one batch
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
|
||||
session.interface()->add_vector_batch(group.id, tensor_list);
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add binary vectors one by one
|
||||
{
|
||||
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
|
||||
for (int64_t k = 0; k < count; k++) {
|
||||
session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
|
||||
if(k%1000 == 0) {
|
||||
CLIENT_LOG_INFO << "add binary vector no." << k;
|
||||
}
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//add binary vectors in one batch
|
||||
{
|
||||
@ -134,11 +134,41 @@ void ClientApp::Run(const std::string &config_file) {
|
||||
|
||||
std::cout << "Search result: " << std::endl;
|
||||
for(auto id : res.id_list) {
|
||||
std::cout << id << std::endl;
|
||||
std::cout << "\t" << id << std::endl;
|
||||
}
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
//search binary vector
|
||||
{
|
||||
server::TimeRecorder rc("Search binary batch top_k");
|
||||
VecBinaryTensorList tensor_list;
|
||||
for(int32_t k = 350; k < 360; k++) {
|
||||
VecBinaryTensor bin_tensor;
|
||||
bin_tensor.tensor.resize(dim * sizeof(double));
|
||||
double* d_p = new double[dim];
|
||||
for (int32_t i = 0; i < dim; i++) {
|
||||
d_p[i] = (double)(i + k);
|
||||
}
|
||||
memcpy(const_cast<char*>(bin_tensor.tensor.data()), d_p, dim * sizeof(double));
|
||||
tensor_list.tensor_list.emplace_back(bin_tensor);
|
||||
}
|
||||
|
||||
VecSearchResultList res;
|
||||
VecTimeRangeList range;
|
||||
session.interface()->search_binary_vector_batch(res, group.id, 5, tensor_list, range);
|
||||
|
||||
std::cout << "Search binary batch result: " << std::endl;
|
||||
for(size_t i = 0 ; i < res.result_list.size(); i++) {
|
||||
std::cout << "No " << i << ":" << std::endl;
|
||||
for(auto id : res.result_list[i].id_list) {
|
||||
std::cout << "\t" << id << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
rc.Elapse("done!");
|
||||
}
|
||||
|
||||
} catch (std::exception& ex) {
|
||||
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user