From 545d4725fb040c1dc392b76ca0e70c841e9eae80 Mon Sep 17 00:00:00 2001 From: Gao Date: Wed, 5 Jun 2024 14:31:57 +0800 Subject: [PATCH] fix: correct get vector data size for bf16/fp16/binary vector (#33377) related #22837 Signed-off-by: chasingegg --- internal/core/src/common/Types.h | 1 + internal/core/src/index/VectorDiskIndex.cpp | 4 +-- internal/core/src/index/VectorMemIndex.cpp | 6 ++-- internal/core/unittest/test_float16.cpp | 32 ++++++++++----------- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/internal/core/src/common/Types.h b/internal/core/src/common/Types.h index e22f1e230e..4d14577828 100644 --- a/internal/core/src/common/Types.h +++ b/internal/core/src/common/Types.h @@ -57,6 +57,7 @@ using distance_t = float; using float16 = knowhere::fp16; using bfloat16 = knowhere::bf16; +using bin1 = knowhere::bin1; enum class DataType { NONE = 0, diff --git a/internal/core/src/index/VectorDiskIndex.cpp b/internal/core/src/index/VectorDiskIndex.cpp index 5e3f23dd87..73811b5077 100644 --- a/internal/core/src/index/VectorDiskIndex.cpp +++ b/internal/core/src/index/VectorDiskIndex.cpp @@ -449,10 +449,10 @@ VectorDiskAnnIndex::GetVector(const DatasetPtr dataset) const { auto row_num = res.value()->GetRows(); auto dim = res.value()->GetDim(); int64_t data_size; - if (is_in_bin_list(index_type)) { + if constexpr (std::is_same_v) { data_size = dim / 8 * row_num; } else { - data_size = dim * row_num * sizeof(float); + data_size = dim * row_num * sizeof(T); } std::vector raw_data; raw_data.resize(data_size); diff --git a/internal/core/src/index/VectorMemIndex.cpp b/internal/core/src/index/VectorMemIndex.cpp index fabbface68..580c568e10 100644 --- a/internal/core/src/index/VectorMemIndex.cpp +++ b/internal/core/src/index/VectorMemIndex.cpp @@ -669,10 +669,10 @@ VectorMemIndex::GetVector(const DatasetPtr dataset) const { auto row_num = res.value()->GetRows(); auto dim = res.value()->GetDim(); int64_t data_size; - if (is_in_bin_list(index_type)) { + if constexpr (std::is_same_v) { data_size = dim / 8 * row_num; } else { - data_size = dim * row_num * sizeof(float); + data_size = dim * row_num * sizeof(T); } std::vector raw_data; raw_data.resize(data_size); @@ -954,7 +954,7 @@ VectorMemIndex::LoadFromFileV2(const Config& config) { LOG_INFO("load vector index done"); } template class VectorMemIndex; -template class VectorMemIndex; +template class VectorMemIndex; template class VectorMemIndex; template class VectorMemIndex; diff --git a/internal/core/unittest/test_float16.cpp b/internal/core/unittest/test_float16.cpp index 04ca348d45..bf172a5d47 100644 --- a/internal/core/unittest/test_float16.cpp +++ b/internal/core/unittest/test_float16.cpp @@ -196,14 +196,14 @@ TEST(Float16, GetVector) { auto vector = result.get()->mutable_vectors()->float16_vector(); EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(float16)); - // EXPECT_TRUE(vector.size() == num_inserted * dim); - // for (size_t i = 0; i < num_inserted; ++i) { - // auto id = ids_ds->GetIds()[i]; - // for (size_t j = 0; j < 128; ++j) { - // EXPECT_TRUE(vector[i * dim + j] == - // fakevec[(id % per_batch) * dim + j]); - // } - // } + for (size_t i = 0; i < num_inserted; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < 128; ++j) { + EXPECT_TRUE( + reinterpret_cast(vector.data())[i * dim + j] == + fakevec[(id % per_batch) * dim + j]); + } + } } } @@ -453,14 +453,14 @@ TEST(BFloat16, GetVector) { auto vector = result.get()->mutable_vectors()->bfloat16_vector(); EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(bfloat16)); - // EXPECT_TRUE(vector.size() == num_inserted * dim); - // for (size_t i = 0; i < num_inserted; ++i) { - // auto id = ids_ds->GetIds()[i]; - // for (size_t j = 0; j < 128; ++j) { - // EXPECT_TRUE(vector[i * dim + j] == - // fakevec[(id % per_batch) * dim + j]); - // } - // } + for (size_t i = 0; i < num_inserted; ++i) { + auto id = ids_ds->GetIds()[i]; + for (size_t j = 0; j < 128; ++j) { + EXPECT_TRUE( + reinterpret_cast(vector.data())[i * dim + j] == + fakevec[(id % per_batch) * dim + j]); + } + } } }