milvus/cpp/src/wrapper/knowhere/vec_index.cpp
xj.lin c6ce5772ba fix bad alloc and add idmap
Former-commit-id: bd2686574ad9010e33dcf34f3ae45308d5b3c971
2019-07-10 16:28:11 +08:00

178 lines
5.1 KiB
C++

////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include "knowhere/index/vector_index/ivf.h"
#include "knowhere/index/vector_index/idmap.h"
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "vec_index.h"
#include "vec_impl.h"
namespace zilliz {
namespace milvus {
namespace engine {
// Binary file sink used when an index is serialized to disk.
// The file is opened for binary output in the constructor and closed in the
// destructor.
struct FileIOWriter {
    std::fstream fs;
    std::string name;

    FileIOWriter(const std::string &fname);
    ~FileIOWriter();

    // Writes `size` bytes starting at `ptr`.
    // Returns `size` when the stream is still good afterwards, 0 otherwise.
    size_t operator()(void *ptr, size_t size);
};

// Binary file source used when an index is loaded from disk.
// The file is opened for binary input in the constructor and closed in the
// destructor.
struct FileIOReader {
    std::fstream fs;
    std::string name;

    FileIOReader(const std::string &fname);
    ~FileIOReader();

    // Reads up to `size` bytes into `ptr` from the current stream position.
    // Returns the number of bytes actually read (less than `size` on a short read).
    size_t operator()(void *ptr, size_t size);

    // Positional read. Not implemented: always returns 0 and reads nothing.
    size_t operator()(void *ptr, size_t size, size_t pos);
};

FileIOReader::FileIOReader(const std::string &fname)
    : fs(fname, std::ios::in | std::ios::binary), name(fname) {
}

FileIOReader::~FileIOReader() {
    fs.close();
}

size_t FileIOReader::operator()(void *ptr, size_t size) {
    fs.read(static_cast<char *>(ptr), static_cast<std::streamsize>(size));
    // gcount() reports what the last unformatted read actually delivered, so
    // callers can detect truncated files.
    return static_cast<size_t>(fs.gcount());
}

size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
    // Stub kept for interface compatibility; no data is read.
    return 0;
}

FileIOWriter::FileIOWriter(const std::string &fname)
    : fs(fname, std::ios::out | std::ios::binary), name(fname) {
}

FileIOWriter::~FileIOWriter() {
    fs.close();
}

size_t FileIOWriter::operator()(void *ptr, size_t size) {
    fs.write(static_cast<char *>(ptr), static_cast<std::streamsize>(size));
    return fs.good() ? size : 0;
}
// Creates an empty index wrapper for the requested `type`.
//
// FAISS_IDMAP and FAISS_IVFFLAT_MIX get dedicated wrapper classes (BFIndex and
// IVFMixIndex); every other supported type is wrapped in VecIndexImpl together
// with its type tag. Returns nullptr for any type not listed here.
//
// The literal 0 handed to the GPU index constructors is presumably the GPU
// device id -- TODO confirm against the knowhere headers.
VecIndexPtr GetVecIndexFactory(const IndexType &type) {
    using KnowhereIndexPtr = std::shared_ptr<zilliz::knowhere::VectorIndex>;

    // Types with their own wrapper class are handled up front.
    if (type == IndexType::FAISS_IDMAP) {
        KnowhereIndexPtr flat = std::make_shared<zilliz::knowhere::IDMAP>();
        return std::make_shared<BFIndex>(flat);
    }
    if (type == IndexType::FAISS_IVFFLAT_MIX) {
        KnowhereIndexPtr gpu_ivf = std::make_shared<zilliz::knowhere::GPUIVF>(0);
        return std::make_shared<IVFMixIndex>(gpu_ivf);
    }

    // Everything else shares the generic VecIndexImpl wrapper.
    KnowhereIndexPtr backend;
    switch (type) {
        case IndexType::FAISS_IVFFLAT_CPU:
            backend = std::make_shared<zilliz::knowhere::IVF>();
            break;
        case IndexType::FAISS_IVFFLAT_GPU:
            backend = std::make_shared<zilliz::knowhere::GPUIVF>(0);
            break;
        case IndexType::FAISS_IVFPQ_CPU:
            backend = std::make_shared<zilliz::knowhere::IVFPQ>();
            break;
        case IndexType::FAISS_IVFPQ_GPU:
            backend = std::make_shared<zilliz::knowhere::GPUIVFPQ>(0);
            break;
        case IndexType::SPTAG_KDT_RNT_CPU:
            backend = std::make_shared<zilliz::knowhere::CPUKDTRNG>();
            break;
        // IndexType::NSG is intentionally not handled yet.
        // TODO(linxj): bug. Re-enable once fixed:
        //     backend = std::make_shared<zilliz::knowhere::NSG>();
        default:
            return nullptr;
    }
    return std::make_shared<VecIndexImpl>(backend, type);
}
// Creates an index of `index_type` and populates it from `index_binary`.
//
// Returns the loaded index, or nullptr when GetVecIndexFactory() does not
// support `index_type` (it returns nullptr in that case, which must not be
// dereferenced).
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
    auto index = GetVecIndexFactory(index_type);
    if (index == nullptr) {
        return nullptr;
    }
    index->Load(index_binary);
    return index;
}
// Loads an index from the file at `location`.
//
// Expected layout (the one produced by write_index()):
//     [IndexType tag]
//     repeated until end of file:
//         [size_t  name length][name bytes]
//         [int64_t blob length][blob bytes]
//
// Returns nullptr when the file cannot be opened, is truncated, carries the
// INVALID type tag, or declares a length that exceeds the remaining file size.
// Otherwise returns whatever LoadVecIndex() yields for the decoded binaries.
VecIndexPtr read_index(const std::string &location) {
    FileIOReader reader(location);
    auto &fs = reader.fs;
    if (!fs.is_open()) {
        return nullptr;
    }

    // Determine the file size so every declared length can be bounds-checked.
    fs.seekg(0, std::ios::end);
    const std::streamoff file_size = fs.tellg();
    fs.seekg(0);
    if (!fs || file_size < 0) {
        return nullptr;
    }
    const size_t length = static_cast<size_t>(file_size);
    size_t rp = 0;  // number of bytes consumed so far

    // Reads exactly `size` bytes at the current position and advances `rp`;
    // returns false on a short read or stream error.
    auto read_exact = [&](void *dst, size_t size) -> bool {
        fs.read(static_cast<char *>(dst), static_cast<std::streamsize>(size));
        if (!fs || static_cast<size_t>(fs.gcount()) != size) {
            return false;
        }
        rp += size;
        return true;
    };

    auto current_type = IndexType::INVALID;
    if (!read_exact(&current_type, sizeof(current_type)) || current_type == IndexType::INVALID) {
        return nullptr;
    }

    knowhere::BinarySet load_data_list;
    while (rp < length) {
        size_t meta_length = 0;
        if (!read_exact(&meta_length, sizeof(meta_length)) || meta_length > length - rp) {
            return nullptr;
        }
        // std::string owns the name buffer, so nothing leaks on any path.
        std::string meta(meta_length, '\0');
        if (meta_length > 0 && !read_exact(&meta[0], meta_length)) {
            return nullptr;
        }

        // Stored as int64_t by write_index(); read it back with the same width.
        int64_t bin_length = 0;
        if (!read_exact(&bin_length, sizeof(bin_length)) || bin_length < 0 ||
            static_cast<uint64_t>(bin_length) > length - rp) {
            return nullptr;
        }

        // The buffer comes from new[], so it needs the array deleter.
        std::shared_ptr<uint8_t> bin(new uint8_t[static_cast<size_t>(bin_length)],
                                     std::default_delete<uint8_t[]>());
        if (bin_length > 0 && !read_exact(bin.get(), static_cast<size_t>(bin_length))) {
            return nullptr;
        }

        load_data_list.Append(meta, bin, bin_length);
    }

    return LoadVecIndex(current_type, load_data_list);
}
// Serializes `index` and stores it in the file at `location`.
//
// Output layout:
//     [IndexType tag]
//     for each entry of the serialized binary map:
//         [size_t  name length][name bytes]
//         [int64_t blob length][blob bytes]
//
// Serialization happens before the file is opened.
void write_index(VecIndexPtr index, const std::string &location) {
    auto serialized = index->Serialize();
    auto type_tag = index->GetType();

    FileIOWriter out(location);
    out(&type_tag, sizeof(IndexType));

    for (const auto &entry : serialized.binary_map_) {
        // Entry name, length-prefixed.
        const std::string &key = entry.first;
        size_t key_length = key.length();
        out(&key_length, sizeof(key_length));
        // The writer takes a non-const pointer, so the constness is cast away.
        out(const_cast<char *>(key.c_str()), key_length);

        // Entry payload, length-prefixed with a 64-bit signed size.
        const auto &blob = entry.second;
        int64_t blob_length = blob->size;
        out(&blob_length, sizeof(blob_length));
        out(blob->data.get(), blob_length);
    }
}
}
}
}