Add new API for batch binary-vector search

Former-commit-id: c035ac8fcfe672576c6207fca149c3d4030e7b53
This commit is contained in:
groot 2019-04-28 12:07:59 +08:00
parent 13d7ebb5cf
commit 95f3900910
5 changed files with 234 additions and 49 deletions

View File

@ -126,13 +126,13 @@ VecServiceHandler::search_vector(VecSearchResult &_return,
const VecTimeRangeList &time_range_list) {
TimeRecordWrapper rc("search_vector()");
SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k
<< ", vector size = " << tensor.tensor.size()
<< ", vector dimension = " << tensor.tensor.size()
<< ", time range list size = " << time_range_list.range_list.size();
VecTensorList tensor_list;
tensor_list.tensor_list.push_back(tensor);
VecSearchResultList result;
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, result);
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, result);
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
scheduler.ExecuteTask(task_ptr);
@ -154,7 +154,48 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
<< ", vector list size = " << tensor_list.tensor_list.size()
<< ", time range list size = " << time_range_list.range_list.size();
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, tensor_list, time_range_list, _return);
BaseTaskPtr task_ptr = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, _return);
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
scheduler.ExecuteTask(task_ptr);
}
// Single binary-vector search entry point.
// Wraps the one tensor into a single-element batch, dispatches a
// SearchVectorTask through the scheduler, and copies the first (and only)
// result back to the caller. Logs an error if the task produced no result.
void
VecServiceHandler::search_binary_vector(VecSearchResult& _return,
                                        const std::string& group_id,
                                        const int64_t top_k,
                                        const VecBinaryTensor& tensor,
                                        const VecTimeRangeList& time_range_list) {
    TimeRecordWrapper rc("search_binary_vector()");
    SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k
                     << ", vector dimension = " << tensor.tensor.size()
                     << ", time range list size = " << time_range_list.range_list.size();

    // Reuse the batch machinery with a one-element tensor list.
    VecBinaryTensorList single_list;
    single_list.tensor_list.push_back(tensor);

    VecSearchResultList batch_result;
    BaseTaskPtr task = SearchVectorTask::Create(group_id, top_k, &single_list, time_range_list, batch_result);
    VecServiceScheduler::GetInstance().ExecuteTask(task);

    if(batch_result.result_list.empty()) {
        SERVER_LOG_ERROR << "No search result returned";
    } else {
        _return = batch_result.result_list[0];
    }
}
// Batch binary-vector search entry point.
// Builds a SearchVectorTask over the whole binary tensor list and executes it
// synchronously via the scheduler; results are written directly into _return.
void
VecServiceHandler::search_binary_vector_batch(VecSearchResultList& _return,
                                              const std::string& group_id,
                                              const int64_t top_k,
                                              const VecBinaryTensorList& tensor_list,
                                              const VecTimeRangeList& time_range_list) {
    TimeRecordWrapper rc("search_binary_vector_batch()");
    SERVER_LOG_TRACE << "group_id = " << group_id << ", top_k = " << top_k
                     << ", vector list size = " << tensor_list.tensor_list.size()
                     << ", time range list size = " << time_range_list.range_list.size();

    BaseTaskPtr task = SearchVectorTask::Create(group_id, top_k, &tensor_list, time_range_list, _return);
    VecServiceScheduler::GetInstance().ExecuteTask(task);
}

View File

@ -67,6 +67,9 @@ public:
void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecTimeRangeList& time_range_list);
void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecTimeRangeList& time_range_list);
void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecTimeRangeList& time_range_list);
};

View File

@ -301,7 +301,7 @@ uint64_t AddBatchVectorTask::GetVecDimension(uint64_t index) const {
if(index >= bin_tensor_list_->tensor_list.size()){
return 0;
}
return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size();
return (uint64_t) bin_tensor_list_->tensor_list[index].tensor.size()/8;
} else {
return 0;
}
@ -341,8 +341,6 @@ std::string AddBatchVectorTask::GetVecID(uint64_t index) const {
ServerError AddBatchVectorTask::OnExecute() {
try {
TimeRecorder rc("Add vector batch");
engine::meta::GroupSchema group_info;
group_info.group_id = group_id_;
engine::Status stat = DB()->get_group(group_info);
@ -351,6 +349,7 @@ ServerError AddBatchVectorTask::OnExecute() {
return SERVER_UNEXPECTED_ERROR;
}
TimeRecorder rc("Add vector batch");
uint64_t group_dim = group_info.dimension;
uint64_t vec_count = GetVecListCount();
std::vector<float> vec_f;
@ -402,13 +401,29 @@ ServerError AddBatchVectorTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
SearchVectorTask::SearchVectorTask(const std::string& group_id,
const int64_t top_k,
const VecTensorList& tensor_list,
const VecTensorList* tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result)
: BaseTask(DQL_TASK_GROUP),
group_id_(group_id),
top_k_(top_k),
tensor_list_(tensor_list),
bin_tensor_list_(nullptr),
time_range_list_(time_range_list),
result_(result) {
}
SearchVectorTask::SearchVectorTask(const std::string& group_id,
const int64_t top_k,
const VecBinaryTensorList* bin_tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result)
: BaseTask(DQL_TASK_GROUP),
group_id_(group_id),
top_k_(top_k),
tensor_list_(nullptr),
bin_tensor_list_(bin_tensor_list),
time_range_list_(time_range_list),
result_(result) {
@ -416,21 +431,101 @@ SearchVectorTask::SearchVectorTask(const std::string& group_id,
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
const int64_t top_k,
const VecTensorList& tensor_list,
const VecTensorList* tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result) {
return std::shared_ptr<BaseTask>(new SearchVectorTask(group_id, top_k, tensor_list, time_range_list, result));
}
// Factory for the binary-tensor variant of SearchVectorTask.
// The constructor is protected, so callers must go through this helper;
// bin_tensor_list is borrowed (not owned) and must outlive the task.
BaseTaskPtr SearchVectorTask::Create(const std::string& group_id,
                                     const int64_t top_k,
                                     const VecBinaryTensorList* bin_tensor_list,
                                     const VecTimeRangeList& time_range_list,
                                     VecSearchResultList& result) {
    SearchVectorTask* task = new SearchVectorTask(group_id, top_k, bin_tensor_list, time_range_list, result);
    return std::shared_ptr<BaseTask>(task);
}
// Flattens whichever tensor list is set (float or binary) into a single
// row-major float buffer of size count * dim.
// - Float path: each VecTensor holds `dim` doubles, converted element-wise.
// - Binary path: each VecBinaryTensor holds raw bytes that encode doubles,
//   so the dimension is byte-size / 8 and the bytes are reinterpreted as
//   doubles before conversion.
// Returns SERVER_INVALID_ARGUMENT if any tensor's dimension differs from the
// first one; SERVER_SUCCESS otherwise (including when neither list is set,
// in which case `data` is left untouched).
// Fixes: replaced C-style casts with named casts; made inner loop counters
// unsigned to match `dim` (was int64_t vs uint64_t signed/unsigned mix).
ServerError SearchVectorTask::GetTargetData(std::vector<float>& data) const {
    if(tensor_list_ && !tensor_list_->tensor_list.empty()) {
        uint64_t count = tensor_list_->tensor_list.size();
        uint64_t dim = tensor_list_->tensor_list[0].tensor.size();
        data.resize(count*dim);
        for(uint64_t i = 0; i < count; i++) {
            if(tensor_list_->tensor_list[i].tensor.size() != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << tensor_list_->tensor_list[i].tensor.size();
                return SERVER_INVALID_ARGUMENT;
            }
            const double* d_p = tensor_list_->tensor_list[i].tensor.data();
            for(uint64_t k = 0; k < dim; k++) {
                data[i*dim + k] = static_cast<float>(d_p[k]);
            }
        }
    } else if(bin_tensor_list_ && !bin_tensor_list_->tensor_list.empty()) {
        uint64_t count = bin_tensor_list_->tensor_list.size();
        // 8 bytes (one double) per dimension in the binary encoding.
        uint64_t dim = bin_tensor_list_->tensor_list[0].tensor.size()/8;
        data.resize(count*dim);
        for(uint64_t i = 0; i < count; i++) {
            if(bin_tensor_list_->tensor_list[i].tensor.size()/8 != dim) {
                SERVER_LOG_ERROR << "Invalid vector dimension: " << bin_tensor_list_->tensor_list[i].tensor.size()/8;
                return SERVER_INVALID_ARGUMENT;
            }
            const double* d_p = reinterpret_cast<const double*>(bin_tensor_list_->tensor_list[i].tensor.data());
            for(uint64_t k = 0; k < dim; k++) {
                data[i*dim + k] = static_cast<float>(d_p[k]);
            }
        }
    }

    return SERVER_SUCCESS;
}
// Dimension of the query vectors, taken from the first tensor of whichever
// list is set. Binary tensors store 8 bytes (one double) per dimension,
// hence the division by 8. Returns 0 when no tensors are available.
uint64_t SearchVectorTask::GetTargetDimension() const {
    if(tensor_list_ != nullptr && !tensor_list_->tensor_list.empty()) {
        return tensor_list_->tensor_list[0].tensor.size();
    }
    if(bin_tensor_list_ != nullptr && !bin_tensor_list_->tensor_list.empty()) {
        return bin_tensor_list_->tensor_list[0].tensor.size()/8;
    }
    return 0;
}
// Number of query vectors in whichever tensor list is set.
// Fix: the original fell off the end of this non-void function (undefined
// behavior) when both tensor_list_ and bin_tensor_list_ were null; now
// returns 0 in that case.
uint64_t SearchVectorTask::GetTargetCount() const {
    if(tensor_list_) {
        return tensor_list_->tensor_list.size();
    } else if(bin_tensor_list_) {
        return bin_tensor_list_->tensor_list.size();
    }
    return 0;
}
ServerError SearchVectorTask::OnExecute() {
try {
std::vector<float> vec_f;
for(const VecTensor& tensor : tensor_list_.tensor_list) {
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
engine::meta::GroupSchema group_info;
group_info.group_id = group_id_;
engine::Status stat = DB()->get_group(group_info);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
return SERVER_UNEXPECTED_ERROR;
}
uint64_t vec_dim = GetTargetDimension();
if(vec_dim != group_info.dimension) {
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
<< " vs. group dimension:" << group_info.dimension;
return SERVER_INVALID_ARGUMENT;
}
TimeRecorder rc("Search vector");
std::vector<float> vec_f;
ServerError err = GetTargetData(vec_f);
if(err != SERVER_SUCCESS) {
return err;
}
uint64_t vec_count = GetTargetCount();
engine::QueryResults results;
engine::Status stat = DB()->search(group_id_, (size_t)top_k_, tensor_list_.tensor_list.size(), vec_f.data(), results);
stat = DB()->search(group_id_, (size_t)top_k_, vec_count, vec_f.data(), results);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
return SERVER_UNEXPECTED_ERROR;
@ -438,7 +533,7 @@ ServerError SearchVectorTask::OnExecute() {
for(engine::QueryResult& res : results){
VecSearchResult v_res;
std::string nid_prefix = group_id_ + "_";
for(auto id : results[0]) {
for(auto id : res) {
std::string sid;
std::string nid = nid_prefix + std::to_string(id);
IVecIdMapper::GetInstance()->Get(nid, sid);

View File

@ -129,24 +129,40 @@ class SearchVectorTask : public BaseTask {
public:
static BaseTaskPtr Create(const std::string& group_id,
const int64_t top_k,
const VecTensorList& tensor_list,
const VecTensorList* tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result);
static BaseTaskPtr Create(const std::string& group_id,
const int64_t top_k,
const VecBinaryTensorList* bin_tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result);
protected:
SearchVectorTask(const std::string& group_id,
const int64_t top_k,
const VecTensorList& tensor_list,
const VecTensorList* tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result);
ServerError OnExecute() override;
SearchVectorTask(const std::string& group_id,
const int64_t top_k,
const VecBinaryTensorList* bin_tensor_list,
const VecTimeRangeList& time_range_list,
VecSearchResultList& result);
ServerError GetTargetData(std::vector<float>& data) const;
uint64_t GetTargetDimension() const;
uint64_t GetTargetCount() const;
ServerError OnExecute() override;
private:
std::string group_id_;
int64_t top_k_;
const VecTensorList& tensor_list_;
const VecTensorList* tensor_list_;
const VecBinaryTensorList* bin_tensor_list_;
const VecTimeRangeList& time_range_list_;
VecSearchResultList& result_;
};

View File

@ -57,7 +57,7 @@ void ClientApp::Run(const std::string &config_file) {
session.interface()->add_group(group);
//prepare data
const int64_t count = 100000;
const int64_t count = 10000;
VecTensorList tensor_list;
VecBinaryTensorList bin_tensor_list;
for (int64_t k = 0; k < count; k++) {
@ -79,36 +79,36 @@ void ClientApp::Run(const std::string &config_file) {
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
}
// //add vectors one by one
// {
// server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
// for (int64_t k = 0; k < count; k++) {
// session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
// if(k%1000 == 0) {
// CLIENT_LOG_INFO << "add normal vector no." << k;
// }
// }
// rc.Elapse("done!");
// }
//
// //add vectors in one batch
// {
// server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
// session.interface()->add_vector_batch(group.id, tensor_list);
// rc.Elapse("done!");
// }
//
// //add binary vectors one by one
// {
// server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
// for (int64_t k = 0; k < count; k++) {
// session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
// if(k%1000 == 0) {
// CLIENT_LOG_INFO << "add binary vector no." << k;
// }
// }
// rc.Elapse("done!");
// }
//add vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors one by one");
for (int64_t k = 0; k < count; k++) {
session.interface()->add_vector(group.id, tensor_list.tensor_list[k]);
if(k%1000 == 0) {
CLIENT_LOG_INFO << "add normal vector no." << k;
}
}
rc.Elapse("done!");
}
//add vectors in one batch
{
server::TimeRecorder rc("Add " + std::to_string(count) + " vectors in one batch");
session.interface()->add_vector_batch(group.id, tensor_list);
rc.Elapse("done!");
}
//add binary vectors one by one
{
server::TimeRecorder rc("Add " + std::to_string(count) + " binary vectors one by one");
for (int64_t k = 0; k < count; k++) {
session.interface()->add_binary_vector(group.id, bin_tensor_list.tensor_list[k]);
if(k%1000 == 0) {
CLIENT_LOG_INFO << "add binary vector no." << k;
}
}
rc.Elapse("done!");
}
//add binary vectors in one batch
{
@ -134,11 +134,41 @@ void ClientApp::Run(const std::string &config_file) {
std::cout << "Search result: " << std::endl;
for(auto id : res.id_list) {
std::cout << id << std::endl;
std::cout << "\t" << id << std::endl;
}
rc.Elapse("done!");
}
//search binary vector
{
server::TimeRecorder rc("Search binary batch top_k");
VecBinaryTensorList tensor_list;
for(int32_t k = 350; k < 360; k++) {
VecBinaryTensor bin_tensor;
bin_tensor.tensor.resize(dim * sizeof(double));
double* d_p = new double[dim];
for (int32_t i = 0; i < dim; i++) {
d_p[i] = (double)(i + k);
}
memcpy(const_cast<char*>(bin_tensor.tensor.data()), d_p, dim * sizeof(double));
tensor_list.tensor_list.emplace_back(bin_tensor);
}
VecSearchResultList res;
VecTimeRangeList range;
session.interface()->search_binary_vector_batch(res, group.id, 5, tensor_list, range);
std::cout << "Search binary batch result: " << std::endl;
for(size_t i = 0 ; i < res.result_list.size(); i++) {
std::cout << "No " << i << ":" << std::endl;
for(auto id : res.result_list[i].id_list) {
std::cout << "\t" << id << std::endl;
}
}
rc.Elapse("done!");
}
} catch (std::exception& ex) {
CLIENT_LOG_ERROR << "request encounter exception: " << ex.what();
}