mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-01 00:15:30 +08:00
Merge branch 'add_unittest' into 'branch-0.3.1'
Add unittest. See merge request megasearch/milvus!185. Former-commit-id: fe37fe22833770f07ccf552364e3e7e31659232c
This commit is contained in:
commit
b7966df1f1
@ -4,6 +4,7 @@
|
||||
* Proprietary and confidential.
|
||||
******************************************************************************/
|
||||
#include <src/server/ServerConfig.h>
|
||||
#include <src/metrics/Metrics.h>
|
||||
#include "Log.h"
|
||||
|
||||
#include "src/cache/CpuCacheMgr.h"
|
||||
@ -16,55 +17,6 @@ namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
struct FileIOWriter {
|
||||
std::fstream fs;
|
||||
std::string name;
|
||||
|
||||
FileIOWriter(const std::string &fname);
|
||||
~FileIOWriter();
|
||||
size_t operator()(void *ptr, size_t size);
|
||||
};
|
||||
|
||||
struct FileIOReader {
|
||||
std::fstream fs;
|
||||
std::string name;
|
||||
|
||||
FileIOReader(const std::string &fname);
|
||||
~FileIOReader();
|
||||
size_t operator()(void *ptr, size_t size);
|
||||
size_t operator()(void *ptr, size_t size, size_t pos);
|
||||
};
|
||||
|
||||
FileIOReader::FileIOReader(const std::string &fname) {
|
||||
name = fname;
|
||||
fs = std::fstream(name, std::ios::in | std::ios::binary);
|
||||
}
|
||||
|
||||
FileIOReader::~FileIOReader() {
|
||||
fs.close();
|
||||
}
|
||||
|
||||
size_t FileIOReader::operator()(void *ptr, size_t size) {
|
||||
fs.read(reinterpret_cast<char *>(ptr), size);
|
||||
}
|
||||
|
||||
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
FileIOWriter::FileIOWriter(const std::string &fname) {
|
||||
name = fname;
|
||||
fs = std::fstream(name, std::ios::out | std::ios::binary);
|
||||
}
|
||||
|
||||
FileIOWriter::~FileIOWriter() {
|
||||
fs.close();
|
||||
}
|
||||
|
||||
size_t FileIOWriter::operator()(void *ptr, size_t size) {
|
||||
fs.write(reinterpret_cast<char *>(ptr), size);
|
||||
}
|
||||
|
||||
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension,
|
||||
const std::string &location,
|
||||
EngineType type)
|
||||
@ -89,7 +41,7 @@ VecIndexPtr ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_IVFFLAT_GPU: {
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_GPU);
|
||||
index = GetVecIndexFactory(IndexType::FAISS_IVFFLAT_MIX);
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_IVFFLAT_CPU: {
|
||||
@ -130,91 +82,34 @@ size_t ExecutionEngineImpl::PhysicalSize() const {
|
||||
}
|
||||
|
||||
Status ExecutionEngineImpl::Serialize() {
|
||||
auto binaryset = index_->Serialize();
|
||||
|
||||
FileIOWriter writer(location_);
|
||||
writer(¤t_type, sizeof(current_type));
|
||||
for (auto &iter: binaryset.binary_map_) {
|
||||
auto meta = iter.first.c_str();
|
||||
size_t meta_length = iter.first.length();
|
||||
writer(&meta_length, sizeof(meta_length));
|
||||
writer((void *) meta, meta_length);
|
||||
|
||||
auto binary = iter.second;
|
||||
size_t binary_length = binary->size;
|
||||
writer(&binary_length, sizeof(binary_length));
|
||||
writer((void *) binary->data.get(), binary_length);
|
||||
}
|
||||
write_index(index_, location_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ExecutionEngineImpl::Load() {
|
||||
index_ = Load(location_);
|
||||
index_ = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location_);
|
||||
bool to_cache = false;
|
||||
auto start_time = METRICS_NOW_TIME;
|
||||
if (!index_) {
|
||||
index_ = read_index(location_);
|
||||
to_cache = true;
|
||||
ENGINE_LOG_DEBUG << "Disk io from: " << location_;
|
||||
}
|
||||
|
||||
if (to_cache) {
|
||||
Cache();
|
||||
auto end_time = METRICS_NOW_TIME;
|
||||
auto total_time = METRICS_MICROSECONDS(start_time, end_time);
|
||||
|
||||
server::Metrics::GetInstance().FaissDiskLoadDurationSecondsHistogramObserve(total_time);
|
||||
double total_size = Size();
|
||||
|
||||
server::Metrics::GetInstance().FaissDiskLoadSizeBytesHistogramObserve(total_size);
|
||||
server::Metrics::GetInstance().FaissDiskLoadIOSpeedGaugeSet(total_size / double(total_time));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VecIndexPtr ExecutionEngineImpl::Load(const std::string &location) {
|
||||
knowhere::BinarySet load_data_list;
|
||||
FileIOReader reader(location);
|
||||
reader.fs.seekg(0, reader.fs.end);
|
||||
size_t length = reader.fs.tellg();
|
||||
reader.fs.seekg(0);
|
||||
|
||||
size_t rp = 0;
|
||||
reader(¤t_type, sizeof(current_type));
|
||||
rp += sizeof(current_type);
|
||||
while (rp < length) {
|
||||
size_t meta_length;
|
||||
reader(&meta_length, sizeof(meta_length));
|
||||
rp += sizeof(meta_length);
|
||||
reader.fs.seekg(rp);
|
||||
|
||||
auto meta = new char[meta_length];
|
||||
reader(meta, meta_length);
|
||||
rp += meta_length;
|
||||
reader.fs.seekg(rp);
|
||||
|
||||
size_t bin_length;
|
||||
reader(&bin_length, sizeof(bin_length));
|
||||
rp += sizeof(bin_length);
|
||||
reader.fs.seekg(rp);
|
||||
|
||||
auto bin = new uint8_t[bin_length];
|
||||
reader(bin, bin_length);
|
||||
rp += bin_length;
|
||||
|
||||
auto binptr = std::make_shared<uint8_t>();
|
||||
binptr.reset(bin);
|
||||
load_data_list.Append(std::string(meta, meta_length), binptr, bin_length);
|
||||
}
|
||||
|
||||
auto index_type = IndexType::INVALID;
|
||||
switch (current_type) {
|
||||
case EngineType::FAISS_IDMAP: {
|
||||
index_type = IndexType::FAISS_IDMAP;
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_IVFFLAT_CPU: {
|
||||
index_type = IndexType::FAISS_IVFFLAT_CPU;
|
||||
break;
|
||||
}
|
||||
case EngineType::FAISS_IVFFLAT_GPU: {
|
||||
index_type = IndexType::FAISS_IVFFLAT_GPU;
|
||||
break;
|
||||
}
|
||||
case EngineType::SPTAG_KDT_RNT_CPU: {
|
||||
index_type = IndexType::SPTAG_KDT_RNT_CPU;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
ENGINE_LOG_ERROR << "wrong index_type";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return LoadVecIndex(index_type, load_data_list);
|
||||
}
|
||||
|
||||
Status ExecutionEngineImpl::Merge(const std::string &location) {
|
||||
if (location == location_) {
|
||||
return Status::Error("Cannot Merge Self");
|
||||
@ -223,15 +118,17 @@ Status ExecutionEngineImpl::Merge(const std::string &location) {
|
||||
|
||||
auto to_merge = zilliz::milvus::cache::CpuCacheMgr::GetInstance()->GetIndex(location);
|
||||
if (!to_merge) {
|
||||
to_merge = Load(location);
|
||||
to_merge = read_index(location);
|
||||
}
|
||||
|
||||
auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge);
|
||||
index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
|
||||
return Status::OK();
|
||||
if (auto file_index = std::dynamic_pointer_cast<BFIndex>(to_merge)) {
|
||||
index_->Add(file_index->Count(), file_index->GetRawVectors(), file_index->GetRawIds());
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status::Error("file index type is not idmap");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(linxj): add config
|
||||
ExecutionEnginePtr
|
||||
ExecutionEngineImpl::BuildIndex(const std::string &location) {
|
||||
ENGINE_LOG_DEBUG << "Build index file: " << location << " from: " << location_;
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
|
||||
#include <src/utils/Log.h>
|
||||
#include "knowhere/index/vector_index/idmap.h"
|
||||
#include "knowhere/index/vector_index/gpu_ivf.h"
|
||||
|
||||
#include "vec_impl.h"
|
||||
#include "data_transfer.h"
|
||||
@ -98,6 +99,10 @@ int64_t VecIndexImpl::Count() {
|
||||
return index_->Count();
|
||||
}
|
||||
|
||||
/// Returns the IndexType tag held in this wrapper's `type` member.
IndexType VecIndexImpl::GetType() {
    return this->type;
}
|
||||
|
||||
float *BFIndex::GetRawVectors() {
|
||||
auto raw_index = std::dynamic_pointer_cast<IDMAP>(index_);
|
||||
if (raw_index) { return raw_index->GetRawVectors(); }
|
||||
@ -126,6 +131,38 @@ void BFIndex::BuildAll(const long &nb,
|
||||
index_->Add(dataset, cfg);
|
||||
}
|
||||
|
||||
// TODO(linxj): add lock here.
|
||||
void IVFMixIndex::BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) {
|
||||
dim = cfg["dim"].as<int>();
|
||||
auto dataset = GenDatasetWithIds(nb, dim, xb, ids);
|
||||
|
||||
auto preprocessor = index_->BuildPreprocessor(dataset, cfg);
|
||||
index_->set_preprocessor(preprocessor);
|
||||
auto nlist = int(nb / 1000000.0 * 16384);
|
||||
auto cfg_t = Config::object{{"nlist", nlist}, {"dim", dim}};
|
||||
auto model = index_->Train(dataset, cfg_t);
|
||||
index_->set_index_model(model);
|
||||
index_->Add(dataset, cfg);
|
||||
|
||||
if (auto device_index = std::dynamic_pointer_cast<GPUIVF>(index_)) {
|
||||
auto host_index = device_index->Copy_index_gpu_to_cpu();
|
||||
index_ = host_index;
|
||||
} else {
|
||||
// TODO(linxj): LOG ERROR
|
||||
}
|
||||
}
|
||||
|
||||
/// Replaces the wrapped index with a new IVF instance, loads `index_binary`
/// into it, then refreshes the cached `dim` from Dimension().
void IVFMixIndex::Load(const zilliz::knowhere::BinarySet &index_binary) {
    std::shared_ptr<zilliz::knowhere::VectorIndex> host_index = std::make_shared<IVF>();
    // Install first, load second: same ordering as before, so index_ already
    // points at the new object while it is being loaded.
    index_ = host_index;
    host_index->Load(index_binary);
    dim = this->Dimension();
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,13 +17,15 @@ namespace engine {
|
||||
|
||||
class VecIndexImpl : public VecIndex {
|
||||
public:
|
||||
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : index_(std::move(index)) {};
|
||||
explicit VecIndexImpl(std::shared_ptr<zilliz::knowhere::VectorIndex> index, const IndexType &type)
|
||||
: index_(std::move(index)), type(type) {};
|
||||
void BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
IndexType GetType() override;
|
||||
int64_t Dimension() override;
|
||||
int64_t Count() override;
|
||||
void Add(const long &nb, const float *xb, const long *ids, const Config &cfg) override;
|
||||
@ -33,21 +35,36 @@ class VecIndexImpl : public VecIndex {
|
||||
|
||||
protected:
|
||||
int64_t dim = 0;
|
||||
IndexType type = IndexType::INVALID;
|
||||
std::shared_ptr<zilliz::knowhere::VectorIndex> index_ = nullptr;
|
||||
};
|
||||
|
||||
class BFIndex : public VecIndexImpl {
|
||||
class IVFMixIndex : public VecIndexImpl {
|
||||
public:
|
||||
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index)) {};
|
||||
void Build(const int64_t& d);
|
||||
float* GetRawVectors();
|
||||
explicit IVFMixIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
|
||||
IndexType::FAISS_IVFFLAT_MIX) {};
|
||||
void BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
int64_t* GetRawIds();
|
||||
void Load(const zilliz::knowhere::BinarySet &index_binary) override;
|
||||
};
|
||||
|
||||
class BFIndex : public VecIndexImpl {
|
||||
public:
|
||||
explicit BFIndex(std::shared_ptr<zilliz::knowhere::VectorIndex> index) : VecIndexImpl(std::move(index),
|
||||
IndexType::FAISS_IDMAP) {};
|
||||
void Build(const int64_t &d);
|
||||
float *GetRawVectors();
|
||||
void BuildAll(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
const Config &cfg,
|
||||
const long &nt,
|
||||
const float *xt) override;
|
||||
int64_t *GetRawIds();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@ -16,7 +16,56 @@ namespace zilliz {
|
||||
namespace milvus {
|
||||
namespace engine {
|
||||
|
||||
// TODO(linxj): index_type => enum struct
|
||||
/// Binary output-file helper: opens `fname` for writing on construction and
/// closes it on destruction. Calling the object writes a raw byte range.
struct FileIOWriter {
    std::fstream fs;
    std::string name;

    FileIOWriter(const std::string &fname);
    ~FileIOWriter();

    /// Writes `size` bytes starting at `ptr`.
    /// @return `size` when the stream is still good after the write, else 0.
    size_t operator()(void *ptr, size_t size);
};

/// Binary input-file helper: opens `fname` for reading on construction and
/// closes it on destruction. Calling the object reads a raw byte range.
struct FileIOReader {
    std::fstream fs;
    std::string name;

    FileIOReader(const std::string &fname);
    ~FileIOReader();

    /// Reads up to `size` bytes into `ptr` from the current stream position.
    /// @return the number of bytes actually read.
    size_t operator()(void *ptr, size_t size);

    /// Positional read; not implemented. Performs no I/O and returns 0.
    size_t operator()(void *ptr, size_t size, size_t pos);
};

// Members are initialised in declaration order (fs, then name), so the stream
// is opened from `fname` directly rather than from the not-yet-built `name`.
FileIOReader::FileIOReader(const std::string &fname)
    : fs(fname, std::ios::in | std::ios::binary), name(fname) {
}

FileIOReader::~FileIOReader() {
    fs.close();
}

size_t FileIOReader::operator()(void *ptr, size_t size) {
    fs.read(reinterpret_cast<char *>(ptr), static_cast<std::streamsize>(size));
    // The function is declared size_t, so it must return a value; report the
    // byte count of the read that just completed.
    return static_cast<size_t>(fs.gcount());
}

size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
    // Stub kept for interface compatibility.
    (void) ptr;
    (void) size;
    (void) pos;
    return 0;
}

FileIOWriter::FileIOWriter(const std::string &fname)
    : fs(fname, std::ios::out | std::ios::binary), name(fname) {
}

FileIOWriter::~FileIOWriter() {
    fs.close();
}

size_t FileIOWriter::operator()(void *ptr, size_t size) {
    fs.write(reinterpret_cast<char *>(ptr), static_cast<std::streamsize>(size));
    // ostream::write does not report a count; success is all-or-nothing.
    return fs.good() ? size : 0;
}
|
||||
|
||||
|
||||
VecIndexPtr GetVecIndexFactory(const IndexType &type) {
|
||||
std::shared_ptr<zilliz::knowhere::VectorIndex> index;
|
||||
switch (type) {
|
||||
@ -32,6 +81,10 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
|
||||
index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
|
||||
break;
|
||||
}
|
||||
case IndexType::FAISS_IVFFLAT_MIX: {
|
||||
index = std::make_shared<zilliz::knowhere::GPUIVF>(0);
|
||||
return std::make_shared<IVFMixIndex>(index);
|
||||
}
|
||||
case IndexType::FAISS_IVFPQ_CPU: {
|
||||
index = std::make_shared<zilliz::knowhere::IVFPQ>();
|
||||
break;
|
||||
@ -44,15 +97,15 @@ VecIndexPtr GetVecIndexFactory(const IndexType &type) {
|
||||
index = std::make_shared<zilliz::knowhere::CPUKDTRNG>();
|
||||
break;
|
||||
}
|
||||
//case IndexType::NSG: { // TODO(linxj): bug.
|
||||
// index = std::make_shared<zilliz::knowhere::NSG>();
|
||||
// break;
|
||||
//}
|
||||
//case IndexType::NSG: { // TODO(linxj): bug.
|
||||
// index = std::make_shared<zilliz::knowhere::NSG>();
|
||||
// break;
|
||||
//}
|
||||
default: {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return std::make_shared<VecIndexImpl>(index);
|
||||
return std::make_shared<VecIndexImpl>(index, type);
|
||||
}
|
||||
|
||||
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
|
||||
@ -61,6 +114,64 @@ VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::Bi
|
||||
return index;
|
||||
}
|
||||
|
||||
/// Loads an index from the file at `location`.
///
/// Expected layout, matching what write_index() emits:
///   IndexType tag,
///   then repeated until end of file:
///     size_t  meta_length,   meta_length bytes of key,
///     int64_t binary_length, binary_length bytes of payload.
/// Fields are raw in-memory representations, so files are only portable
/// between builds with the same type sizes and byte order.
///
/// @param location path of the index file.
/// @return the index produced by LoadVecIndex(), or nullptr when the file
///         cannot be opened or is truncated/corrupt.
VecIndexPtr read_index(const std::string &location) {
    FileIOReader reader(location);
    if (!reader.fs.is_open()) {
        return nullptr;
    }

    reader.fs.seekg(0, reader.fs.end);
    const std::streamoff end_off = reader.fs.tellg();
    if (end_off < 0) {
        return nullptr;
    }
    const size_t length = static_cast<size_t>(end_off);
    reader.fs.seekg(0);

    // Read straight from the stream and verify the byte count, so a short
    // file is detected instead of leaving the destination uninitialised.
    auto read_exact = [&reader](void *dst, size_t bytes) -> bool {
        reader.fs.read(reinterpret_cast<char *>(dst), static_cast<std::streamsize>(bytes));
        return static_cast<size_t>(reader.fs.gcount()) == bytes;
    };

    size_t rp = 0;  // bytes consumed so far
    auto current_type = IndexType::INVALID;
    if (!read_exact(&current_type, sizeof(current_type))) {
        return nullptr;
    }
    rp += sizeof(current_type);

    knowhere::BinarySet load_data_list;
    while (rp < length) {
        size_t meta_length = 0;
        if (!read_exact(&meta_length, sizeof(meta_length))) {
            return nullptr;
        }
        rp += sizeof(meta_length);
        // A length larger than what is left in the file means corruption;
        // bail out before allocating based on it.
        if (meta_length > length - rp) {
            return nullptr;
        }

        std::string meta(meta_length, '\0');
        if (!read_exact(&meta[0], meta_length)) {
            return nullptr;
        }
        rp += meta_length;

        // write_index() stores the payload length as int64_t.
        int64_t bin_length = 0;
        if (!read_exact(&bin_length, sizeof(bin_length))) {
            return nullptr;
        }
        rp += sizeof(bin_length);
        if (bin_length < 0 || static_cast<size_t>(bin_length) > length - rp) {
            return nullptr;
        }
        const size_t bin_size = static_cast<size_t>(bin_length);

        // The buffer comes from new[], so it must be released with delete[].
        std::shared_ptr<uint8_t> binptr(new uint8_t[bin_size], std::default_delete<uint8_t[]>());
        if (!read_exact(binptr.get(), bin_size)) {
            return nullptr;
        }
        rp += bin_size;

        load_data_list.Append(meta, binptr, bin_size);
    }

    return LoadVecIndex(current_type, load_data_list);
}
|
||||
|
||||
/// Serializes `index` and writes it to the file at `location`.
///
/// Layout written:
///   IndexType tag,
///   then for every entry of the serialized binary map:
///     size_t  meta_length,   the key bytes,
///     int64_t binary_length, the payload bytes.
/// All fields are written as raw in-memory representations.
///
/// @param index    index to persist; must not be null.
/// @param location destination path; opened via FileIOWriter.
void write_index(VecIndexPtr index, const std::string &location) {
    auto binaryset = index->Serialize();
    auto index_type = index->GetType();

    FileIOWriter writer(location);
    if (!writer.fs.is_open()) {
        // Nothing can be written; the interface has no way to report this.
        return;
    }

    // Write through the stream directly, with const-correct casts.
    auto &out = writer.fs;
    auto put = [&out](const void *src, size_t bytes) {
        out.write(reinterpret_cast<const char *>(src), static_cast<std::streamsize>(bytes));
    };

    put(&index_type, sizeof(IndexType));
    for (const auto &entry : binaryset.binary_map_) {
        const auto &meta = entry.first;
        const size_t meta_length = meta.length();
        put(&meta_length, sizeof(meta_length));
        put(meta.c_str(), meta_length);

        // Bind by reference: no need to bump the shared_ptr refcount.
        const auto &binary = entry.second;
        const int64_t binary_length = binary->size;
        put(&binary_length, sizeof(binary_length));
        put(binary->data.get(), static_cast<size_t>(binary_length));
    }
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,6 +20,18 @@ namespace engine {
|
||||
// TODO(linxj): jsoncons => rapidjson or other.
|
||||
using Config = zilliz::knowhere::Config;
|
||||
|
||||
// Index implementation tag.
// write_index() stores this value raw in index files and read_index() reads
// it back, so the numeric values below must stay stable.
enum class IndexType {
    INVALID = 0,
    FAISS_IDMAP = 1,
    FAISS_IVFFLAT_CPU = 2,
    FAISS_IVFFLAT_GPU = 3,
    FAISS_IVFFLAT_MIX = 4, // build on gpu and search on cpu
    FAISS_IVFPQ_CPU = 5,
    FAISS_IVFPQ_GPU = 6,
    SPTAG_KDT_RNT_CPU = 7,
    //NSG,
};
|
||||
|
||||
class VecIndex {
|
||||
public:
|
||||
virtual void BuildAll(const long &nb,
|
||||
@ -40,6 +52,8 @@ class VecIndex {
|
||||
long *ids,
|
||||
const Config &cfg = Config()) = 0;
|
||||
|
||||
virtual IndexType GetType() = 0;
|
||||
|
||||
virtual int64_t Dimension() = 0;
|
||||
|
||||
virtual int64_t Count() = 0;
|
||||
@ -51,16 +65,9 @@ class VecIndex {
|
||||
|
||||
using VecIndexPtr = std::shared_ptr<VecIndex>;
|
||||
|
||||
enum class IndexType {
|
||||
INVALID = 0,
|
||||
FAISS_IDMAP = 1,
|
||||
FAISS_IVFFLAT_CPU,
|
||||
FAISS_IVFFLAT_GPU,
|
||||
FAISS_IVFPQ_CPU,
|
||||
FAISS_IVFPQ_GPU,
|
||||
SPTAG_KDT_RNT_CPU,
|
||||
//NSG,
|
||||
};
|
||||
extern void write_index(VecIndexPtr index, const std::string &location);
|
||||
|
||||
extern VecIndexPtr read_index(const std::string &location);
|
||||
|
||||
extern VecIndexPtr GetVecIndexFactory(const IndexType &type);
|
||||
|
||||
|
||||
2
cpp/thirdparty/knowhere
vendored
2
cpp/thirdparty/knowhere
vendored
@ -1 +1 @@
|
||||
Subproject commit c3123501d62f69f9eacaa73ee96c0daeb24620a5
|
||||
Subproject commit ca99a6899be4e8a0806452656cf0f2be19d79c1a
|
||||
@ -28,11 +28,37 @@ class KnowhereWrapperTest
|
||||
|
||||
//auto generator = GetGenerateFactory(generator_type);
|
||||
auto generator = std::make_shared<DataGenBase>();
|
||||
generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids);
|
||||
generator->GenData(dim, nb, nq, xb, xq, ids, k, gt_ids, gt_dis);
|
||||
|
||||
index_ = GetVecIndexFactory(index_type);
|
||||
}
|
||||
|
||||
void AssertResult(const std::vector<long> &ids, const std::vector<float> &dis) {
|
||||
EXPECT_EQ(ids.size(), nq * k);
|
||||
EXPECT_EQ(dis.size(), nq * k);
|
||||
|
||||
for (auto i = 0; i < nq; i++) {
|
||||
EXPECT_EQ(ids[i * k], gt_ids[i * k]);
|
||||
EXPECT_EQ(dis[i * k], gt_dis[i * k]);
|
||||
}
|
||||
|
||||
int match = 0;
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
for (int j = 0; j < k; ++j) {
|
||||
for (int l = 0; l < k; ++l) {
|
||||
if (ids[i * nq + j] == gt_ids[i * nq + l]) match++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto precision = float(match) / (nq * k);
|
||||
EXPECT_GT(precision, 0.5);
|
||||
std::cout << std::endl << "Precision: " << precision
|
||||
<< ", match: " << match
|
||||
<< ", total: " << nq * k
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
protected:
|
||||
IndexType index_type;
|
||||
Config train_cfg;
|
||||
@ -50,126 +76,88 @@ class KnowhereWrapperTest
|
||||
|
||||
// Ground Truth
|
||||
std::vector<long> gt_ids;
|
||||
std::vector<float> gt_dis;
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(WrapperParam, KnowhereWrapperTest,
|
||||
Values(
|
||||
// ["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
|
||||
//["Index type", "Generator type", "dim", "nb", "nq", "k", "build config", "search config"]
|
||||
std::make_tuple(IndexType::FAISS_IVFFLAT_CPU, "Default",
|
||||
64, 10000, 10, 10,
|
||||
64, 100000, 10, 10,
|
||||
Config::object{{"nlist", 100}, {"dim", 64}},
|
||||
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 20}}
|
||||
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
|
||||
),
|
||||
std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",
|
||||
64, 10000, 10, 10,
|
||||
Config::object{{"TPTNumber", 1}, {"dim", 64}},
|
||||
//std::make_tuple(IndexType::FAISS_IVFFLAT_GPU, "Default",
|
||||
// 64, 10000, 10, 10,
|
||||
// Config::object{{"nlist", 100}, {"dim", 64}},
|
||||
// Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 40}}
|
||||
//),
|
||||
std::make_tuple(IndexType::FAISS_IVFFLAT_MIX, "Default",
|
||||
64, 100000, 10, 10,
|
||||
Config::object{{"nlist", 100}, {"dim", 64}},
|
||||
Config::object{{"dim", 64}, {"k", 10}, {"nprobe", 10}}
|
||||
),
|
||||
std::make_tuple(IndexType::FAISS_IDMAP, "Default",
|
||||
64, 100000, 10, 10,
|
||||
Config::object{{"dim", 64}},
|
||||
Config::object{{"dim", 64}, {"k", 10}}
|
||||
)
|
||||
//std::make_tuple(IndexType::SPTAG_KDT_RNT_CPU, "Default",
|
||||
// 64, 10000, 10, 10,
|
||||
// Config::object{{"TPTNumber", 1}, {"dim", 64}},
|
||||
// Config::object{{"dim", 64}, {"k", 10}}
|
||||
//)
|
||||
)
|
||||
);
|
||||
|
||||
void AssertAnns(const std::vector<long> >,
|
||||
const std::vector<long> &res,
|
||||
const int &nq,
|
||||
const int &k) {
|
||||
EXPECT_EQ(res.size(), nq * k);
|
||||
|
||||
for (auto i = 0; i < nq; i++) {
|
||||
EXPECT_EQ(gt[i * k], res[i * k]);
|
||||
}
|
||||
|
||||
int match = 0;
|
||||
for (int i = 0; i < nq; ++i) {
|
||||
for (int j = 0; j < k; ++j) {
|
||||
for (int l = 0; l < k; ++l) {
|
||||
if (gt[i * nq + j] == res[i * nq + l]) match++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(linxj): percision check
|
||||
EXPECT_GT(float(match/nq*k), 0.5);
|
||||
}
|
||||
|
||||
TEST_P(KnowhereWrapperTest, base_test) {
|
||||
std::vector<long> res_ids;
|
||||
float *D = new float[k * nq];
|
||||
res_ids.resize(nq * k);
|
||||
EXPECT_EQ(index_->GetType(), index_type);
|
||||
|
||||
auto elems = nq * k;
|
||||
std::vector<int64_t> res_ids(elems);
|
||||
std::vector<float> res_dis(elems);
|
||||
|
||||
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
|
||||
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
|
||||
AssertAnns(gt_ids, res_ids, nq, k);
|
||||
delete[] D;
|
||||
index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
|
||||
AssertResult(res_ids, res_dis);
|
||||
}
|
||||
|
||||
TEST_P(KnowhereWrapperTest, serialize_test) {
|
||||
std::vector<long> res_ids;
|
||||
float *D = new float[k * nq];
|
||||
res_ids.resize(nq * k);
|
||||
TEST_P(KnowhereWrapperTest, serialize) {
|
||||
EXPECT_EQ(index_->GetType(), index_type);
|
||||
|
||||
auto elems = nq * k;
|
||||
std::vector<int64_t> res_ids(elems);
|
||||
std::vector<float> res_dis(elems);
|
||||
index_->BuildAll(nb, xb.data(), ids.data(), train_cfg);
|
||||
index_->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
|
||||
AssertAnns(gt_ids, res_ids, nq, k);
|
||||
index_->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
|
||||
AssertResult(res_ids, res_dis);
|
||||
|
||||
{
|
||||
auto binaryset = index_->Serialize();
|
||||
//int fileno = 0;
|
||||
//const std::string &base_name = "/tmp/wrapper_serialize_test_bin_";
|
||||
//std::vector<std::string> filename_list;
|
||||
//std::vector<std::pair<std::string, size_t >> meta_list;
|
||||
//for (auto &iter: binaryset.binary_map_) {
|
||||
// const std::string &filename = base_name + std::to_string(fileno);
|
||||
// FileIOWriter writer(filename);
|
||||
// writer(iter.second->data.get(), iter.second->size);
|
||||
//
|
||||
// meta_list.push_back(std::make_pair(iter.first, iter.second.size));
|
||||
// filename_list.push_back(filename);
|
||||
// ++fileno;
|
||||
//}
|
||||
//
|
||||
//BinarySet load_data_list;
|
||||
//for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
|
||||
// auto bin_size = meta_list[i].second;
|
||||
// FileIOReader reader(filename_list[i]);
|
||||
// std::vector<uint8_t> load_data(bin_size);
|
||||
// reader(load_data.data(), bin_size);
|
||||
// load_data_list.Append(meta_list[i].first, load_data);
|
||||
//}
|
||||
auto binary = index_->Serialize();
|
||||
auto type = index_->GetType();
|
||||
auto new_index = GetVecIndexFactory(type);
|
||||
new_index->Load(binary);
|
||||
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
|
||||
EXPECT_EQ(new_index->Count(), index_->Count());
|
||||
|
||||
int fileno = 0;
|
||||
std::vector<std::string> filename_list;
|
||||
const std::string &base_name = "/tmp/wrapper_serialize_test_bin_";
|
||||
std::vector<std::pair<std::string, size_t >> meta_list;
|
||||
for (auto &iter: binaryset.binary_map_) {
|
||||
const std::string &filename = base_name + std::to_string(fileno);
|
||||
FileIOWriter writer(filename);
|
||||
writer(iter.second->data.get(), iter.second->size);
|
||||
|
||||
meta_list.emplace_back(std::make_pair(iter.first, iter.second->size));
|
||||
filename_list.push_back(filename);
|
||||
++fileno;
|
||||
}
|
||||
|
||||
BinarySet load_data_list;
|
||||
for (int i = 0; i < filename_list.size() && i < meta_list.size(); ++i) {
|
||||
auto bin_size = meta_list[i].second;
|
||||
FileIOReader reader(filename_list[i]);
|
||||
|
||||
auto load_data = new uint8_t[bin_size];
|
||||
reader(load_data, bin_size);
|
||||
auto data = std::make_shared<uint8_t>();
|
||||
data.reset(load_data);
|
||||
load_data_list.Append(meta_list[i].first, data, bin_size);
|
||||
}
|
||||
|
||||
|
||||
res_ids.clear();
|
||||
res_ids.resize(nq * k);
|
||||
auto new_index = GetVecIndexFactory(index_type);
|
||||
new_index->Load(load_data_list);
|
||||
new_index->Search(nq, xq.data(), D, res_ids.data(), search_cfg);
|
||||
AssertAnns(gt_ids, res_ids, nq, k);
|
||||
std::vector<int64_t> res_ids(elems);
|
||||
std::vector<float> res_dis(elems);
|
||||
new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
|
||||
AssertResult(res_ids, res_dis);
|
||||
}
|
||||
|
||||
delete[] D;
|
||||
{
|
||||
std::string file_location = "/tmp/whatever";
|
||||
write_index(index_, file_location);
|
||||
auto new_index = read_index(file_location);
|
||||
EXPECT_EQ(new_index->GetType(), index_type);
|
||||
EXPECT_EQ(new_index->Dimension(), index_->Dimension());
|
||||
EXPECT_EQ(new_index->Count(), index_->Count());
|
||||
|
||||
std::vector<int64_t> res_ids(elems);
|
||||
std::vector<float> res_dis(elems);
|
||||
new_index->Search(nq, xq.data(), res_dis.data(), res_ids.data(), search_cfg);
|
||||
AssertResult(res_ids, res_dis);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ DataGenPtr GetGenerateFactory(const std::string &gen_type) {
|
||||
|
||||
void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
|
||||
float *xb, float *xq, long *ids,
|
||||
const int &k, long *gt_ids) {
|
||||
const int &k, long *gt_ids, float *gt_dis) {
|
||||
for (auto i = 0; i < nb; ++i) {
|
||||
for (auto j = 0; j < dim; ++j) {
|
||||
//p_data[i * d + j] = float(base + i);
|
||||
@ -35,8 +35,7 @@ void DataGenBase::GenData(const int &dim, const int &nb, const int &nq,
|
||||
faiss::IndexFlatL2 index(dim);
|
||||
//index.add_with_ids(nb, xb, ids);
|
||||
index.add(nb, xb);
|
||||
float *D = new float[k * nq];
|
||||
index.search(nq, xq, k, D, gt_ids);
|
||||
index.search(nq, xq, k, gt_dis, gt_ids);
|
||||
}
|
||||
|
||||
void DataGenBase::GenData(const int &dim,
|
||||
@ -46,36 +45,12 @@ void DataGenBase::GenData(const int &dim,
|
||||
std::vector<float> &xq,
|
||||
std::vector<long> &ids,
|
||||
const int &k,
|
||||
std::vector<long> >_ids) {
|
||||
std::vector<long> >_ids,
|
||||
std::vector<float> >_dis) {
|
||||
xb.resize(nb * dim);
|
||||
xq.resize(nq * dim);
|
||||
ids.resize(nb);
|
||||
gt_ids.resize(nq * k);
|
||||
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data());
|
||||
}
|
||||
|
||||
FileIOReader::FileIOReader(const std::string &fname) {
|
||||
name = fname;
|
||||
fs = std::fstream(name, std::ios::in | std::ios::binary);
|
||||
}
|
||||
|
||||
FileIOReader::~FileIOReader() {
|
||||
fs.close();
|
||||
}
|
||||
|
||||
size_t FileIOReader::operator()(void *ptr, size_t size) {
|
||||
fs.read(reinterpret_cast<char *>(ptr), size);
|
||||
}
|
||||
|
||||
FileIOWriter::FileIOWriter(const std::string &fname) {
|
||||
name = fname;
|
||||
fs = std::fstream(name, std::ios::out | std::ios::binary);
|
||||
}
|
||||
|
||||
FileIOWriter::~FileIOWriter() {
|
||||
fs.close();
|
||||
}
|
||||
|
||||
size_t FileIOWriter::operator()(void *ptr, size_t size) {
|
||||
fs.write(reinterpret_cast<char *>(ptr), size);
|
||||
gt_dis.resize(nq * k);
|
||||
GenData(dim, nb, nq, xb.data(), xq.data(), ids.data(), k, gt_ids.data(), gt_dis.data());
|
||||
}
|
||||
|
||||
@ -23,7 +23,7 @@ extern DataGenPtr GetGenerateFactory(const std::string &gen_type);
|
||||
class DataGenBase {
|
||||
public:
|
||||
virtual void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
|
||||
const int &k, long *gt_ids);
|
||||
const int &k, long *gt_ids, float *gt_dis);
|
||||
|
||||
virtual void GenData(const int &dim,
|
||||
const int &nb,
|
||||
@ -32,30 +32,14 @@ class DataGenBase {
|
||||
std::vector<float> &xq,
|
||||
std::vector<long> &ids,
|
||||
const int &k,
|
||||
std::vector<long> >_ids);
|
||||
std::vector<long> >_ids,
|
||||
std::vector<float> >_dis);
|
||||
};
|
||||
|
||||
|
||||
class SanityCheck : public DataGenBase {
|
||||
public:
|
||||
void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
|
||||
const int &k, long *gt_ids) override;
|
||||
};
|
||||
//class SanityCheck : public DataGenBase {
|
||||
// public:
|
||||
// void GenData(const int &dim, const int &nb, const int &nq, float *xb, float *xq, long *ids,
|
||||
// const int &k, long *gt_ids, float *gt_dis) override;
|
||||
//};
|
||||
|
||||
struct FileIOWriter {
|
||||
std::fstream fs;
|
||||
std::string name;
|
||||
|
||||
FileIOWriter(const std::string &fname);
|
||||
~FileIOWriter();
|
||||
size_t operator()(void *ptr, size_t size);
|
||||
};
|
||||
|
||||
struct FileIOReader {
|
||||
std::fstream fs;
|
||||
std::string name;
|
||||
|
||||
FileIOReader(const std::string &fname);
|
||||
~FileIOReader();
|
||||
size_t operator()(void *ptr, size_t size);
|
||||
};
|
||||
|
||||
@ -38,6 +38,10 @@ public:
|
||||
|
||||
}
|
||||
|
||||
// Always reports IndexType::INVALID.
engine::IndexType GetType() override {
    constexpr auto kReportedType = engine::IndexType::INVALID;
    return kReportedType;
}
|
||||
|
||||
virtual void Add(const long &nb,
|
||||
const float *xb,
|
||||
const long *ids,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user