From 1141bbbb061e38eca0f6a283199810e7eddf5676 Mon Sep 17 00:00:00 2001 From: "xiaojun.lin" Date: Sun, 13 Oct 2019 16:57:34 +0800 Subject: [PATCH] update v2 Former-commit-id: 1240499e9e5f0042a2296300b00588ed11dc07c3 --- cpp/src/core/unittest/test_ivf.cpp | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/cpp/src/core/unittest/test_ivf.cpp b/cpp/src/core/unittest/test_ivf.cpp index 1d49d91b7c..987692deb5 100644 --- a/cpp/src/core/unittest/test_ivf.cpp +++ b/cpp/src/core/unittest/test_ivf.cpp @@ -696,6 +696,54 @@ TEST_F(GPURESTEST, copyandsearch) { std::thread search_thread(search_func); std::thread load_thread(load_func); + search_thread.join(); + load_thread.join(); + tc.RecordSection("Copy&search total"); +} + +TEST_F(GPURESTEST, TrainAndSearch) { + index_type = "GPUIVFSQ"; + index_ = IndexFactory(index_type); + + auto conf = std::make_shared(); + conf->nlist = 1638; + conf->d = dim; + conf->gpu_id = device_id; + conf->metric_type = knowhere::METRICTYPE::L2; + conf->k = k; + conf->nbits = 8; + conf->nprobe = 1; + + auto preprocessor = index_->BuildPreprocessor(base_dataset, conf); + index_->set_preprocessor(preprocessor); + auto model = index_->Train(base_dataset, conf); + auto new_index = IndexFactory(index_type); + new_index->set_index_model(model); + new_index->Add(base_dataset, conf); + auto cpu_idx = knowhere::cloner::CopyGpuToCpu(new_index, knowhere::Config()); + cpu_idx->Seal(); + auto search_idx = knowhere::cloner::CopyCpuToGpu(cpu_idx, device_id, knowhere::Config()); + + constexpr int train_count = 1; + constexpr int search_count = 5000; + auto train_stage = [&] { + for (int i = 0; i < train_count; ++i) { + auto model = index_->Train(base_dataset, conf); + auto test_idx = IndexFactory(index_type); + test_idx->set_index_model(model); + test_idx->Add(base_dataset, conf); + } + }; + auto search_stage = [&](knowhere::VectorIndexPtr& search_idx) { + for (int i = 0; i < search_count; ++i) { + auto result = search_idx->Search(query_dataset, conf); + AssertAnns(result, nq, k); + } + }; + + // TimeRecorder tc("record"); + // train_stage(); + // tc.RecordSection("train cost"); // search_stage(search_idx); // tc.RecordSection("search cost");