mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
SQ8H in GPU part2
Former-commit-id: 2d8f5d2858f0ca3e4edf02ae53bc9c195a3c91a3
This commit is contained in:
parent
bf24396571
commit
6b148639ef
@ -180,7 +180,7 @@ IVFSQHybrid::UnsetQuantizer() {
|
||||
ivf_index->quantizer = nullptr;
|
||||
}
|
||||
|
||||
void
|
||||
VectorIndexPtr
|
||||
IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
|
||||
auto quantizer_conf = std::dynamic_pointer_cast<QuantizerCfg>(conf);
|
||||
if (quantizer_conf != nullptr) {
|
||||
@ -207,8 +207,10 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
|
||||
index_composition->mode = quantizer_conf->mode; // only 2
|
||||
|
||||
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
|
||||
index_.reset(gpu_index);
|
||||
gpu_mode = 2; // all in gpu
|
||||
std::shared_ptr<faiss::Index> new_idx;
|
||||
new_idx.reset(gpu_index);
|
||||
auto sq_idx = std::make_shared<IVFSQHybrid>(new_idx, gpu_id_, res);
|
||||
return sq_idx;
|
||||
} else {
|
||||
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
|
||||
}
|
||||
|
||||
@ -60,8 +60,7 @@ class IVFSQHybrid : public GPUIVFSQ {
|
||||
void
|
||||
UnsetQuantizer();
|
||||
|
||||
// todo(xiaojun): return void => VecIndex
|
||||
void
|
||||
VectorIndexPtr
|
||||
LoadData(const knowhere::QuantizerPtr& q, const Config& conf);
|
||||
|
||||
IndexModelPtr
|
||||
|
||||
@ -253,9 +253,9 @@ TEST_P(IVFTest, hybrid) {
|
||||
quantizer_conf->gpu_id = device_id;
|
||||
auto q = hybrid_2_idx->LoadQuantizer(quantizer_conf);
|
||||
quantizer_conf->mode = 2;
|
||||
hybrid_2_idx->LoadData(q, quantizer_conf);
|
||||
auto gpu_idx = hybrid_2_idx->LoadData(q, quantizer_conf);
|
||||
|
||||
auto result = hybrid_2_idx->Search(query_dataset, conf);
|
||||
auto result = gpu_idx->Search(query_dataset, conf);
|
||||
AssertAnns(result, nq, conf->k);
|
||||
PrintResult(result, nq, k);
|
||||
}
|
||||
|
||||
@ -256,11 +256,14 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
|
||||
conf->gpu_id = device_id;
|
||||
|
||||
if (quantizer) {
|
||||
std::cout << "cache hit" << std::endl;
|
||||
// cache hit
|
||||
conf->mode = 2;
|
||||
index_->SetQuantizer(quantizer->Data());
|
||||
index_->LoadData(quantizer->Data(), conf);
|
||||
auto new_index = index_->LoadData(quantizer->Data(), conf);
|
||||
index_ = new_index;
|
||||
} else {
|
||||
std::cout << "cache miss" << std::endl;
|
||||
// cache hit
|
||||
// cache miss
|
||||
if (index_ == nullptr) {
|
||||
ENGINE_LOG_ERROR << "ExecutionEngineImpl: index is null, failed to copy to gpu";
|
||||
@ -268,9 +271,9 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id, bool hybrid) {
|
||||
}
|
||||
conf->mode = 1;
|
||||
auto q = index_->LoadQuantizer(conf);
|
||||
index_->SetQuantizer(q);
|
||||
conf->mode = 2;
|
||||
index_->LoadData(q, conf);
|
||||
auto new_index = index_->LoadData(q, conf);
|
||||
index_ = new_index;
|
||||
|
||||
// cache
|
||||
auto cached_quantizer = std::make_shared<CachedQuantizer>(q);
|
||||
@ -445,7 +448,9 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
|
||||
|
||||
auto status = index_->Search(n, data, distances, labels, conf);
|
||||
|
||||
HybridUnset();
|
||||
if (hybrid) {
|
||||
HybridUnset();
|
||||
}
|
||||
|
||||
if (!status.ok()) {
|
||||
ENGINE_LOG_ERROR << "Search error";
|
||||
|
||||
@ -315,24 +315,21 @@ IVFHybridIndex::UnsetQuantizer() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status
|
||||
VecIndexPtr
|
||||
IVFHybridIndex::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
|
||||
try {
|
||||
// TODO(linxj): Hardcode here
|
||||
if (auto new_idx = std::dynamic_pointer_cast<knowhere::IVFSQHybrid>(index_)) {
|
||||
new_idx->LoadData(q, conf);
|
||||
return std::make_shared<IVFHybridIndex>(new_idx->LoadData(q, conf), type);
|
||||
} else {
|
||||
WRAPPER_LOG_ERROR << "Hybrid mode not support for index type: " << int(type);
|
||||
return Status(KNOWHERE_ERROR, "not support");
|
||||
}
|
||||
} catch (knowhere::KnowhereException& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_UNEXPECTED_ERROR, e.what());
|
||||
} catch (std::exception& e) {
|
||||
WRAPPER_LOG_ERROR << e.what();
|
||||
return Status(KNOWHERE_ERROR, e.what());
|
||||
}
|
||||
return Status::OK();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace engine
|
||||
|
||||
@ -106,7 +106,7 @@ class IVFHybridIndex : public IVFMixIndex {
|
||||
Status
|
||||
UnsetQuantizer() override;
|
||||
|
||||
Status
|
||||
VecIndexPtr
|
||||
LoadData(const knowhere::QuantizerPtr& q, const Config& conf) override;
|
||||
};
|
||||
|
||||
|
||||
@ -103,9 +103,9 @@ class VecIndex : public cache::DataObj {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual Status
|
||||
virtual VecIndexPtr
|
||||
LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
|
||||
return Status::OK();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
virtual Status
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user