From f9967cd744bfee1b7b7f55f6d73f178b334304ee Mon Sep 17 00:00:00 2001 From: groot Date: Fri, 10 May 2019 10:39:21 +0800 Subject: [PATCH] modify thrift api Former-commit-id: 744dec10d157280d9776a7507aaa428ca76aad66 --- cpp/src/thrift/VectorService.thrift | 12 +++ cpp/src/thrift/gen-cpp/VecService.cpp | 68 ++++++++--------- .../thrift/gen-cpp/VectorService_types.cpp | 72 ++++++++++++++---- cpp/src/thrift/gen-cpp/VectorService_types.h | 10 ++- cpp/test_client/src/ClientTest.cpp | 76 +++---------------- 5 files changed, 124 insertions(+), 114 deletions(-) diff --git a/cpp/src/thrift/VectorService.thrift b/cpp/src/thrift/VectorService.thrift index 47334c07df..e14b58bb1a 100644 --- a/cpp/src/thrift/VectorService.thrift +++ b/cpp/src/thrift/VectorService.thrift @@ -85,6 +85,12 @@ struct VecDateTime { 6: required i32 second; } +/** + * time_begin; time range begin + * begine_closed; true means '[', false means '(' + * time_end; set to true to return tensor double array + * end_closed; time range end + */ struct VecTimeRange { 1: required VecDateTime time_begin; 2: required bool begine_closed; @@ -92,9 +98,15 @@ struct VecTimeRange { 4: required bool end_closed; } +/** + * attrib_filter; search condition, for example: "color=red" + * time_ranges; search condition, for example: "date between 1999-02-12 and 2008-10-14" + * return_attribs; specify required attribute names + */ struct VecSearchFilter { 1: optional map attrib_filter; 2: optional list time_ranges; + 3: optional list return_attribs; } service VecService { diff --git a/cpp/src/thrift/gen-cpp/VecService.cpp b/cpp/src/thrift/gen-cpp/VecService.cpp index 148010f2a0..d99b2bfcef 100644 --- a/cpp/src/thrift/gen-cpp/VecService.cpp +++ b/cpp/src/thrift/gen-cpp/VecService.cpp @@ -937,14 +937,14 @@ uint32_t VecService_add_vector_batch_result::read(::apache::thrift::protocol::TP if (ftype == ::apache::thrift::protocol::T_LIST) { { this->success.clear(); - uint32_t _size93; - ::apache::thrift::protocol::TType _etype96; - xfer += iprot->readListBegin(_etype96, _size93); - this->success.resize(_size93); - uint32_t _i97; - for (_i97 = 0; _i97 < _size93; ++_i97) + uint32_t _size99; + ::apache::thrift::protocol::TType _etype102; + xfer += iprot->readListBegin(_etype102, _size99); + this->success.resize(_size99); + uint32_t _i103; + for (_i103 = 0; _i103 < _size99; ++_i103) { - xfer += iprot->readString(this->success[_i97]); + xfer += iprot->readString(this->success[_i103]); } xfer += iprot->readListEnd(); } @@ -983,10 +983,10 @@ uint32_t VecService_add_vector_batch_result::write(::apache::thrift::protocol::T xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_LIST, 0); { xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast(this->success.size())); - std::vector ::const_iterator _iter98; - for (_iter98 = this->success.begin(); _iter98 != this->success.end(); ++_iter98) + std::vector ::const_iterator _iter104; + for (_iter104 = this->success.begin(); _iter104 != this->success.end(); ++_iter104) { - xfer += oprot->writeString((*_iter98)); + xfer += oprot->writeString((*_iter104)); } xfer += oprot->writeListEnd(); } @@ -1031,14 +1031,14 @@ uint32_t VecService_add_vector_batch_presult::read(::apache::thrift::protocol::T if (ftype == ::apache::thrift::protocol::T_LIST) { { (*(this->success)).clear(); - uint32_t _size99; - ::apache::thrift::protocol::TType _etype102; - xfer += iprot->readListBegin(_etype102, _size99); - (*(this->success)).resize(_size99); - uint32_t _i103; - for (_i103 = 0; _i103 < _size99; ++_i103) + uint32_t _size105; + ::apache::thrift::protocol::TType _etype108; + xfer += iprot->readListBegin(_etype108, _size105); + (*(this->success)).resize(_size105); + uint32_t _i109; + for (_i109 = 0; _i109 < _size105; ++_i109) { - xfer += iprot->readString((*(this->success))[_i103]); + xfer += iprot->readString((*(this->success))[_i109]); } xfer += iprot->readListEnd(); } @@ -1415,14 +1415,14 @@ uint32_t VecService_add_binary_vector_batch_result::read(::apache::thrift::proto if (ftype == ::apache::thrift::protocol::T_LIST) { { this->success.clear(); - uint32_t _size104; - ::apache::thrift::protocol::TType _etype107; - xfer += iprot->readListBegin(_etype107, _size104); - this->success.resize(_size104); - uint32_t _i108; - for (_i108 = 0; _i108 < _size104; ++_i108) + uint32_t _size110; + ::apache::thrift::protocol::TType _etype113; + xfer += iprot->readListBegin(_etype113, _size110); + this->success.resize(_size110); + uint32_t _i114; + for (_i114 = 0; _i114 < _size110; ++_i114) { - xfer += iprot->readString(this->success[_i108]); + xfer += iprot->readString(this->success[_i114]); } xfer += iprot->readListEnd(); } @@ -1461,10 +1461,10 @@ uint32_t VecService_add_binary_vector_batch_result::write(::apache::thrift::prot xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_LIST, 0); { xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast(this->success.size())); - std::vector ::const_iterator _iter109; - for (_iter109 = this->success.begin(); _iter109 != this->success.end(); ++_iter109) + std::vector ::const_iterator _iter115; + for (_iter115 = this->success.begin(); _iter115 != this->success.end(); ++_iter115) { - xfer += oprot->writeString((*_iter109)); + xfer += oprot->writeString((*_iter115)); } xfer += oprot->writeListEnd(); } @@ -1509,14 +1509,14 @@ uint32_t VecService_add_binary_vector_batch_presult::read(::apache::thrift::prot if (ftype == ::apache::thrift::protocol::T_LIST) { { (*(this->success)).clear(); - uint32_t _size110; - ::apache::thrift::protocol::TType _etype113; - xfer += iprot->readListBegin(_etype113, _size110); - (*(this->success)).resize(_size110); - uint32_t _i114; - for (_i114 = 0; _i114 < _size110; ++_i114) + uint32_t _size116; + ::apache::thrift::protocol::TType _etype119; + xfer += iprot->readListBegin(_etype119, _size116); + (*(this->success)).resize(_size116); + uint32_t _i120; + for (_i120 = 0; _i120 < _size116; ++_i120) { - xfer += iprot->readString((*(this->success))[_i114]); + xfer += iprot->readString((*(this->success))[_i120]); } xfer += iprot->readListEnd(); } diff --git a/cpp/src/thrift/gen-cpp/VectorService_types.cpp b/cpp/src/thrift/gen-cpp/VectorService_types.cpp index 7d36b0a5e4..5fefd60fc5 100644 --- a/cpp/src/thrift/gen-cpp/VectorService_types.cpp +++ b/cpp/src/thrift/gen-cpp/VectorService_types.cpp @@ -1647,6 +1647,11 @@ void VecSearchFilter::__set_time_ranges(const std::vector & val) { this->time_ranges = val; __isset.time_ranges = true; } + +void VecSearchFilter::__set_return_attribs(const std::vector & val) { + this->return_attribs = val; +__isset.return_attribs = true; +} std::ostream& operator<<(std::ostream& out, const VecSearchFilter& obj) { obj.printTo(out); @@ -1718,6 +1723,26 @@ uint32_t VecSearchFilter::read(::apache::thrift::protocol::TProtocol* iprot) { xfer += iprot->skip(ftype); } break; + case 3: + if (ftype == ::apache::thrift::protocol::T_LIST) { + { + this->return_attribs.clear(); + uint32_t _size89; + ::apache::thrift::protocol::TType _etype92; + xfer += iprot->readListBegin(_etype92, _size89); + this->return_attribs.resize(_size89); + uint32_t _i93; + for (_i93 = 0; _i93 < _size89; ++_i93) + { + xfer += iprot->readString(this->return_attribs[_i93]); + } + xfer += iprot->readListEnd(); + } + this->__isset.return_attribs = true; + } else { + xfer += iprot->skip(ftype); + } + break; default: xfer += iprot->skip(ftype); break; @@ -1739,11 +1764,11 @@ uint32_t VecSearchFilter::write(::apache::thrift::protocol::TProtocol* oprot) co xfer += oprot->writeFieldBegin("attrib_filter", ::apache::thrift::protocol::T_MAP, 1); { xfer += oprot->writeMapBegin(::apache::thrift::protocol::T_STRING, ::apache::thrift::protocol::T_STRING, static_cast(this->attrib_filter.size())); - std::map ::const_iterator _iter89; - for (_iter89 = this->attrib_filter.begin(); _iter89 != this->attrib_filter.end(); ++_iter89) + std::map ::const_iterator _iter94; + for (_iter94 = this->attrib_filter.begin(); _iter94 != this->attrib_filter.end(); ++_iter94) { - xfer += oprot->writeString(_iter89->first); - xfer += oprot->writeString(_iter89->second); + xfer += oprot->writeString(_iter94->first); + xfer += oprot->writeString(_iter94->second); } xfer += oprot->writeMapEnd(); } @@ -1753,10 +1778,23 @@ uint32_t VecSearchFilter::write(::apache::thrift::protocol::TProtocol* oprot) co xfer += oprot->writeFieldBegin("time_ranges", ::apache::thrift::protocol::T_LIST, 2); { xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast(this->time_ranges.size())); - std::vector ::const_iterator _iter90; - for (_iter90 = this->time_ranges.begin(); _iter90 != this->time_ranges.end(); ++_iter90) + std::vector ::const_iterator _iter95; + for (_iter95 = this->time_ranges.begin(); _iter95 != this->time_ranges.end(); ++_iter95) { - xfer += (*_iter90).write(oprot); + xfer += (*_iter95).write(oprot); + } + xfer += oprot->writeListEnd(); + } + xfer += oprot->writeFieldEnd(); + } + if (this->__isset.return_attribs) { + xfer += oprot->writeFieldBegin("return_attribs", ::apache::thrift::protocol::T_LIST, 3); + { + xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast(this->return_attribs.size())); + std::vector ::const_iterator _iter96; + for (_iter96 = this->return_attribs.begin(); _iter96 != this->return_attribs.end(); ++_iter96) + { + xfer += oprot->writeString((*_iter96)); } xfer += oprot->writeListEnd(); } @@ -1771,18 +1809,21 @@ void swap(VecSearchFilter &a, VecSearchFilter &b) { using ::std::swap; swap(a.attrib_filter, b.attrib_filter); swap(a.time_ranges, b.time_ranges); + swap(a.return_attribs, b.return_attribs); swap(a.__isset, b.__isset); } -VecSearchFilter::VecSearchFilter(const VecSearchFilter& other91) { - attrib_filter = other91.attrib_filter; - time_ranges = other91.time_ranges; - __isset = other91.__isset; +VecSearchFilter::VecSearchFilter(const VecSearchFilter& other97) { + attrib_filter = other97.attrib_filter; + time_ranges = other97.time_ranges; + return_attribs = other97.return_attribs; + __isset = other97.__isset; } -VecSearchFilter& VecSearchFilter::operator=(const VecSearchFilter& other92) { - attrib_filter = other92.attrib_filter; - time_ranges = other92.time_ranges; - __isset = other92.__isset; +VecSearchFilter& VecSearchFilter::operator=(const VecSearchFilter& other98) { + attrib_filter = other98.attrib_filter; + time_ranges = other98.time_ranges; + return_attribs = other98.return_attribs; + __isset = other98.__isset; return *this; } void VecSearchFilter::printTo(std::ostream& out) const { @@ -1790,6 +1831,7 @@ void VecSearchFilter::printTo(std::ostream& out) const { out << "VecSearchFilter("; out << "attrib_filter="; (__isset.attrib_filter ? (out << to_string(attrib_filter)) : (out << "")); out << ", " << "time_ranges="; (__isset.time_ranges ? (out << to_string(time_ranges)) : (out << "")); + out << ", " << "return_attribs="; (__isset.return_attribs ? (out << to_string(return_attribs)) : (out << "")); out << ")"; } diff --git a/cpp/src/thrift/gen-cpp/VectorService_types.h b/cpp/src/thrift/gen-cpp/VectorService_types.h index e517934620..d9285efd80 100644 --- a/cpp/src/thrift/gen-cpp/VectorService_types.h +++ b/cpp/src/thrift/gen-cpp/VectorService_types.h @@ -597,9 +597,10 @@ void swap(VecTimeRange &a, VecTimeRange &b); std::ostream& operator<<(std::ostream& out, const VecTimeRange& obj); typedef struct _VecSearchFilter__isset { - _VecSearchFilter__isset() : attrib_filter(false), time_ranges(false) {} + _VecSearchFilter__isset() : attrib_filter(false), time_ranges(false), return_attribs(false) {} bool attrib_filter :1; bool time_ranges :1; + bool return_attribs :1; } _VecSearchFilter__isset; class VecSearchFilter : public virtual ::apache::thrift::TBase { @@ -613,6 +614,7 @@ class VecSearchFilter : public virtual ::apache::thrift::TBase { virtual ~VecSearchFilter() throw(); std::map attrib_filter; std::vector time_ranges; + std::vector return_attribs; _VecSearchFilter__isset __isset; @@ -620,6 +622,8 @@ class VecSearchFilter : public virtual ::apache::thrift::TBase { void __set_time_ranges(const std::vector & val); + void __set_return_attribs(const std::vector & val); + bool operator == (const VecSearchFilter & rhs) const { if (__isset.attrib_filter != rhs.__isset.attrib_filter) @@ -630,6 +634,10 @@ class VecSearchFilter : public virtual ::apache::thrift::TBase { return false; else if (__isset.time_ranges && !(time_ranges == rhs.time_ranges)) return false; + if (__isset.return_attribs != rhs.__isset.return_attribs) + return false; + else if (__isset.return_attribs && !(return_attribs == rhs.return_attribs)) + return false; return true; } bool operator != (const VecSearchFilter &rhs) const { diff --git a/cpp/test_client/src/ClientTest.cpp b/cpp/test_client/src/ClientTest.cpp index 937e527777..7d73522ba1 100644 --- a/cpp/test_client/src/ClientTest.cpp +++ b/cpp/test_client/src/ClientTest.cpp @@ -119,57 +119,6 @@ namespace { } } -//void ClientTest::LoopTest() { -// server::TimeRecorder rc("LoopTest"); -// -// std::string address, protocol; -// int32_t port = 0; -// GetServerAddress(address, port, protocol); -// client::ClientSession session(address, port, protocol); -// -// rc.Record("connection"); -// -// //add group -// VecGroup group; -// group.id = "loop_group"; -// group.dimension = VEC_DIMENSION; -// group.index_type = 0; -// session.interface()->add_group(group); -// rc.Record("add group"); -// -// const int64_t batch = 10000; -// for(int64_t i = 0; i < 1000; i++) { -// { -// VecBinaryTensorList bin_tensor_list; -// BuildVectors(i * batch, (i + 1) * batch, nullptr, &bin_tensor_list); -// rc.Record("build batch no." + std::to_string(i)); -// -// std::vector ids; -// session.interface()->add_binary_vector_batch(ids, group.id, bin_tensor_list); -// rc.Record("add batch no." + std::to_string(i)); -// } -// -// sleep(1); -// rc.Record("sleep 1 second"); -// -// VecTensor tensor; -// for (int32_t k = 0; k < VEC_DIMENSION; k++) { -// tensor.tensor.push_back((double) (k + i*666)); -// } -// -// //do search -// VecSearchResult res; -// VecSearchFilter filter; -// session.interface()->search_vector(res, group.id, 10, tensor, filter); -// rc.Record("search finish"); -// -// std::cout << "Search result: " << std::endl; -// for(VecSearchResultItem& item : res.result_list) { -// std::cout << "\t" << item.uid << std::endl; -// } -// } -//} - TEST(AddVector, CLIENT_TEST) { try { std::string address, protocol; @@ -301,23 +250,22 @@ TEST(SearchVector, CLIENT_TEST) { ASSERT_TRUE(!res.result_list[0].uid.empty()); } -// //empty search -// date.day > 0 ? date.day -= 1 : date.day += 1; -// range.time_begin = date; -// range.time_end = date; -// time_ranges.clear(); -// time_ranges.emplace_back(range); -// filter.__set_time_ranges(time_ranges); -// session.interface()->search_vector(res, GetGroupID(), top_k, tensor, filter); -// -// ASSERT_EQ(res.result_list.size(), 0); + //empty search + date.day > 0 ? date.day -= 1 : date.day += 1; + range.time_begin = date; + range.time_end = date; + time_ranges.clear(); + time_ranges.emplace_back(range); + filter.__set_time_ranges(time_ranges); + session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter); + + ASSERT_EQ(res.result_list.size(), 0); } //search binary vector { const int32_t anchor_index = BATCH_COUNT + 200; const int32_t search_count = 10; - const int64_t top_k = 5; server::TimeRecorder rc("Search binary batch top_k"); VecBinaryTensorList tensor_list; for(int32_t k = anchor_index; k < anchor_index + search_count; k++) { @@ -333,7 +281,7 @@ TEST(SearchVector, CLIENT_TEST) { VecSearchResultList res; VecSearchFilter filter; - session.interface()->search_binary_vector_batch(res, GetGroupID(), top_k, tensor_list, filter); + session.interface()->search_binary_vector_batch(res, GetGroupID(), TOP_K, tensor_list, filter); std::cout << "Search binary batch result: " << std::endl; for(size_t i = 0 ; i < res.result_list.size(); i++) { @@ -350,7 +298,7 @@ TEST(SearchVector, CLIENT_TEST) { ASSERT_EQ(res.result_list.size(), search_count); for(size_t i = 0 ; i < res.result_list.size(); i++) { - ASSERT_EQ(res.result_list[i].result_list.size(), (uint64_t) top_k); + ASSERT_EQ(res.result_list[i].result_list.size(), (uint64_t) TOP_K); ASSERT_TRUE(!res.result_list[i].result_list.empty()); } }