fix: compute the correct GetVector data size for bf16/fp16/binary vectors (#33377)

related #22837

Signed-off-by: chasingegg <chao.gao@zilliz.com>
This commit is contained in:
Gao 2024-06-05 14:31:57 +08:00 committed by GitHub
parent 597f4c5e03
commit 545d4725fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 21 deletions

View File

@ -57,6 +57,7 @@ using distance_t = float;
using float16 = knowhere::fp16;
using bfloat16 = knowhere::bf16;
using bin1 = knowhere::bin1;
enum class DataType {
NONE = 0,

View File

@ -449,10 +449,10 @@ VectorDiskAnnIndex<T>::GetVector(const DatasetPtr dataset) const {
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();
int64_t data_size;
if (is_in_bin_list(index_type)) {
if constexpr (std::is_same_v<T, bin1>) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
data_size = dim * row_num * sizeof(T);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);

View File

@ -669,10 +669,10 @@ VectorMemIndex<T>::GetVector(const DatasetPtr dataset) const {
auto row_num = res.value()->GetRows();
auto dim = res.value()->GetDim();
int64_t data_size;
if (is_in_bin_list(index_type)) {
if constexpr (std::is_same_v<T, bin1>) {
data_size = dim / 8 * row_num;
} else {
data_size = dim * row_num * sizeof(float);
data_size = dim * row_num * sizeof(T);
}
std::vector<uint8_t> raw_data;
raw_data.resize(data_size);
@ -954,7 +954,7 @@ VectorMemIndex<T>::LoadFromFileV2(const Config& config) {
LOG_INFO("load vector index done");
}
template class VectorMemIndex<float>;
template class VectorMemIndex<uint8_t>;
template class VectorMemIndex<bin1>;
template class VectorMemIndex<float16>;
template class VectorMemIndex<bfloat16>;

View File

@ -196,14 +196,14 @@ TEST(Float16, GetVector) {
auto vector = result.get()->mutable_vectors()->float16_vector();
EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(float16));
// EXPECT_TRUE(vector.size() == num_inserted * dim);
// for (size_t i = 0; i < num_inserted; ++i) {
// auto id = ids_ds->GetIds()[i];
// for (size_t j = 0; j < 128; ++j) {
// EXPECT_TRUE(vector[i * dim + j] ==
// fakevec[(id % per_batch) * dim + j]);
// }
// }
for (size_t i = 0; i < num_inserted; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < 128; ++j) {
EXPECT_TRUE(
reinterpret_cast<float16*>(vector.data())[i * dim + j] ==
fakevec[(id % per_batch) * dim + j]);
}
}
}
}
@ -453,14 +453,14 @@ TEST(BFloat16, GetVector) {
auto vector = result.get()->mutable_vectors()->bfloat16_vector();
EXPECT_TRUE(vector.size() == num_inserted * dim * sizeof(bfloat16));
// EXPECT_TRUE(vector.size() == num_inserted * dim);
// for (size_t i = 0; i < num_inserted; ++i) {
// auto id = ids_ds->GetIds()[i];
// for (size_t j = 0; j < 128; ++j) {
// EXPECT_TRUE(vector[i * dim + j] ==
// fakevec[(id % per_batch) * dim + j]);
// }
// }
for (size_t i = 0; i < num_inserted; ++i) {
auto id = ids_ds->GetIds()[i];
for (size_t j = 0; j < 128; ++j) {
EXPECT_TRUE(
reinterpret_cast<bfloat16*>(vector.data())[i * dim + j] ==
fakevec[(id % per_batch) * dim + j]);
}
}
}
}