mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
135 lines
4.0 KiB
C++
135 lines
4.0 KiB
C++
|
|
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <new>
#include <sstream>
#include <vector>

#include "knowhere/index/vector_index/cpu_kdt_rng.h"
#include "knowhere/index/vector_index/definitions.h"
#include "knowhere/adapter/sptag.h"
#include "knowhere/adapter/structure.h"
|
|
|
|
|
|
using namespace zilliz::knowhere;
|
|
|
|
/// Builds a synthetic dataset of @p n rows, each of dimension @p d.
///
/// Every component of row i is float(base + i), and row i gets id i
/// (ids always start at 0; @p base only shifts the vector values).
///
/// @param n     number of rows
/// @param d     dimension of each row
/// @param base  offset added to the row index to produce that row's values
/// @return a Dataset with one int64 "id" array (n entries) and one float
///         "data" tensor of shape {n, d}
/// @throws std::bad_alloc if either buffer cannot be allocated
DatasetPtr
generate_dataset(int64_t n, int64_t d, int64_t base) {
    const auto elems = n * d;

    auto p_data = static_cast<float *>(std::malloc(elems * sizeof(float)));
    // Only one id per row is written and handed to the id array below,
    // so n ids are allocated (not n * d).
    auto p_id = static_cast<int64_t *>(std::malloc(n * sizeof(int64_t)));

    // Checked explicitly rather than with assert() so the check survives
    // NDEBUG builds. malloc(0) may legitimately return nullptr, so only a
    // null result for a non-empty request counts as failure.
    const bool data_failed = (p_data == nullptr && elems > 0);
    const bool id_failed = (p_id == nullptr && n > 0);
    if (data_failed || id_failed) {
        // free(nullptr) is a no-op, so releasing both never double-frees
        // and never leaks the buffer that did succeed.
        std::free(p_data);
        std::free(p_id);
        throw std::bad_alloc();
    }

    // int64_t indices match the int64_t bounds, so large n cannot overflow.
    for (int64_t i = 0; i < n; ++i) {
        for (int64_t j = 0; j < d; ++j) {
            p_data[i * d + j] = float(base + i);
        }
        p_id[i] = i;
    }

    // NOTE(review): nothing in this function frees p_data / p_id after this
    // point; the *Smart constructors presumably take ownership - confirm.
    std::vector<int64_t> shape{n, d};
    auto tensor = ConstructFloatTensorSmart(reinterpret_cast<uint8_t *>(p_data), elems * sizeof(float), shape);
    std::vector<TensorPtr> tensors{tensor};
    std::vector<FieldPtr> tensor_fields{ConstructFloatField("data")};
    auto tensor_schema = std::make_shared<Schema>(tensor_fields);

    auto id_array = ConstructInt64ArraySmart(reinterpret_cast<uint8_t *>(p_id), n * sizeof(int64_t));
    std::vector<ArrayPtr> arrays{id_array};
    std::vector<FieldPtr> array_fields{ConstructInt64Field("id")};
    // The id array is described by its own int64 "id" field; previously this
    // mistakenly reused tensor_fields (the float "data" field).
    auto array_schema = std::make_shared<Schema>(array_fields);

    return std::make_shared<Dataset>(std::move(arrays), array_schema,
                                     std::move(tensors), tensor_schema);
}
|
|
|
|
/// Builds @p n query vectors of dimension @p d.
///
/// Every component of query i is float(base + i).
///
/// @param n     number of queries
/// @param d     dimension of each query
/// @param k     requested top-k. It is NOT attached to the returned Dataset
///              (the original only stored it in a local Config that was never
///              used); main supplies k through its search Config instead.
///              The parameter is kept for interface compatibility.
/// @param base  offset added to the query index to produce that query's values
/// @return a Dataset holding a single float "data" tensor of shape {n, d}
/// @throws std::bad_alloc if the query buffer cannot be allocated
DatasetPtr
generate_queries(int64_t n, int64_t d, int64_t k, int64_t base) {
    (void)k;  // intentionally unused, see doc comment

    const size_t size = sizeof(float) * n * d;

    auto v = static_cast<float *>(std::malloc(size));
    // malloc(0) may legitimately return nullptr; any other null is a failure.
    if (v == nullptr && size > 0) {
        throw std::bad_alloc();
    }

    // int64_t indices match the int64_t bounds, so large n cannot overflow.
    for (int64_t i = 0; i < n; ++i) {
        for (int64_t j = 0; j < d; ++j) {
            v[i * d + j] = float(base + i);
        }
    }

    // NOTE(review): v is not freed here; MakeMutableBufferSmart presumably
    // takes ownership of it - confirm.
    auto buffer = MakeMutableBufferSmart(reinterpret_cast<uint8_t *>(v), size);
    std::vector<int64_t> shape{n, d};
    auto float_type = std::make_shared<arrow::FloatType>();
    auto tensor = std::make_shared<Tensor>(float_type, buffer, shape);
    std::vector<TensorPtr> data{tensor};

    // The field describes the same float element type as the tensor.
    auto field = std::make_shared<Field>("data", float_type);
    std::vector<FieldPtr> fields{field};
    auto schema = std::make_shared<Schema>(fields);

    return std::make_shared<Dataset>(std::move(data), schema);
}
|
|
|
|
|
|
int
|
|
main(int argc, char *argv[]) {
|
|
auto kdt_index = std::make_shared<CPUKDTRNG>();
|
|
|
|
const auto d = 10;
|
|
const auto k = 3;
|
|
const auto nquery = 10;
|
|
|
|
// ID [0, 99]
|
|
auto train = generate_dataset(100, d, 0);
|
|
// ID [100]
|
|
auto base = generate_dataset(1, d, 0);
|
|
auto queries = generate_queries(nquery, d, k, 0);
|
|
|
|
// Build Preprocessor
|
|
auto preprocessor = kdt_index->BuildPreprocessor(train, Config());
|
|
|
|
// Set Preprocessor
|
|
kdt_index->set_preprocessor(preprocessor);
|
|
|
|
Config train_config;
|
|
train_config["TPTNumber"] = "64";
|
|
// Train
|
|
kdt_index->Train(train, train_config);
|
|
|
|
// Add
|
|
kdt_index->Add(base, Config());
|
|
|
|
auto binary = kdt_index->Serialize();
|
|
auto new_index = std::make_shared<CPUKDTRNG>();
|
|
new_index->Load(binary);
|
|
// auto new_index = kdt_index;
|
|
|
|
Config search_config;
|
|
search_config[META_K] = int64_t (k);
|
|
|
|
// Search
|
|
auto result = new_index->Search(queries, search_config);
|
|
|
|
// Print Result
|
|
{
|
|
auto ids = result->array()[0];
|
|
auto dists = result->array()[1];
|
|
|
|
std::stringstream ss_id;
|
|
std::stringstream ss_dist;
|
|
for (auto i = 0; i < nquery; i++) {
|
|
for (auto j = 0; j < k; ++j) {
|
|
ss_id << *ids->data()->GetValues<int64_t>(1, i * k + j) << " ";
|
|
ss_dist << *dists->data()->GetValues<float>(1, i * k + j) << " ";
|
|
}
|
|
ss_id << std::endl;
|
|
ss_dist << std::endl;
|
|
}
|
|
std::cout << "id\n" << ss_id.str() << std::endl;
|
|
std::cout << "dist\n" << ss_dist.str() << std::endl;
|
|
}
|
|
}
|
|
|
|
|