fix: SkipIndex causes segmentation fault (#35907)

issue: #35882

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2024-09-03 17:15:03 +08:00 committed by GitHub
parent f068729a26
commit f68df9a11e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 34 deletions

View File

@ -119,8 +119,8 @@ SkipIndex::LoadString(milvus::FieldId field_id,
auto chunkMetrics = std::make_unique<FieldChunkMetrics>(); auto chunkMetrics = std::make_unique<FieldChunkMetrics>();
if (num_rows > 0) { if (num_rows > 0) {
auto info = ProcessStringFieldMetrics(var_column); auto info = ProcessStringFieldMetrics(var_column);
chunkMetrics->min_ = Metrics(info.min_); chunkMetrics->min_ = Metrics(std::move(info.min_));
chunkMetrics->max_ = Metrics(info.max_); chunkMetrics->max_ = Metrics(std::move(info.max_));
chunkMetrics->null_count_ = info.null_count_; chunkMetrics->null_count_ = info.null_count_;
} }

View File

@ -19,13 +19,19 @@
namespace milvus { namespace milvus {
using Metrics = std:: using Metrics =
variant<int8_t, int16_t, int32_t, int64_t, float, double, std::string_view>; std::variant<int8_t, int16_t, int32_t, int64_t, float, double, std::string>;
// MetricsDataType is used to avoid a copy when getting the min/max value from FieldChunkMetrics: std::string is exposed as std::string_view, every other type is unchanged.
template <typename T> template <typename T>
using MetricsDataType = using MetricsDataType =
std::conditional_t<std::is_same_v<T, std::string>, std::string_view, T>; std::conditional_t<std::is_same_v<T, std::string>, std::string_view, T>;
// ReverseMetricsDataType is the inverse mapping of MetricsDataType: std::string_view maps back to std::string (the string type held by the Metrics variant), every other type is unchanged.
template <typename T>
using ReverseMetricsDataType =
std::conditional_t<std::is_same_v<T, std::string_view>, std::string, T>;
struct FieldChunkMetrics { struct FieldChunkMetrics {
Metrics min_; Metrics min_;
Metrics max_; Metrics max_;
@ -33,6 +39,22 @@ struct FieldChunkMetrics {
int64_t null_count_; int64_t null_count_;
FieldChunkMetrics() : hasValue_(false){}; FieldChunkMetrics() : hasValue_(false){};
template <typename T>
std::pair<MetricsDataType<T>, MetricsDataType<T>>
GetMinMax() const {
AssertInfo(hasValue_,
"GetMinMax should never be called when hasValue_ is false");
MetricsDataType<T> lower_bound;
MetricsDataType<T> upper_bound;
try {
lower_bound = std::get<ReverseMetricsDataType<T>>(min_);
upper_bound = std::get<ReverseMetricsDataType<T>>(max_);
} catch (const std::bad_variant_access& e) {
return {};
}
return {lower_bound, upper_bound};
}
}; };
class SkipIndex { class SkipIndex {
@ -99,22 +121,6 @@ class SkipIndex {
static constexpr bool value = isAllowedType && !isDisabledType; static constexpr bool value = isAllowedType && !isDisabledType;
}; };
template <typename T>
std::pair<MetricsDataType<T>, MetricsDataType<T>>
GetMinMax(const FieldChunkMetrics& field_chunk_metrics) const {
MetricsDataType<T> lower_bound;
MetricsDataType<T> upper_bound;
try {
lower_bound =
std::get<MetricsDataType<T>>(field_chunk_metrics.min_);
upper_bound =
std::get<MetricsDataType<T>>(field_chunk_metrics.max_);
} catch (const std::bad_variant_access&) {
return {};
}
return {lower_bound, upper_bound};
}
template <typename T> template <typename T>
std::enable_if_t<SkipIndex::IsAllowedType<T>::value, bool> std::enable_if_t<SkipIndex::IsAllowedType<T>::value, bool>
MinMaxUnaryFilter(const FieldChunkMetrics& field_chunk_metrics, MinMaxUnaryFilter(const FieldChunkMetrics& field_chunk_metrics,
@ -123,13 +129,12 @@ class SkipIndex {
if (!field_chunk_metrics.hasValue_) { if (!field_chunk_metrics.hasValue_) {
return false; return false;
} }
std::pair<MetricsDataType<T>, MetricsDataType<T>> minMax = auto [lower_bound, upper_bound] = field_chunk_metrics.GetMinMax<T>();
GetMinMax<T>(field_chunk_metrics); if (lower_bound == MetricsDataType<T>() ||
if (minMax.first == MetricsDataType<T>() || upper_bound == MetricsDataType<T>()) {
minMax.second == MetricsDataType<T>()) {
return false; return false;
} }
return RangeShouldSkip<T>(val, minMax.first, minMax.second, op_type); return RangeShouldSkip<T>(val, lower_bound, upper_bound, op_type);
} }
template <typename T> template <typename T>
@ -150,15 +155,12 @@ class SkipIndex {
if (!field_chunk_metrics.hasValue_) { if (!field_chunk_metrics.hasValue_) {
return false; return false;
} }
std::pair<MetricsDataType<T>, MetricsDataType<T>> minMax = auto [lower_bound, upper_bound] = field_chunk_metrics.GetMinMax<T>();
GetMinMax<T>(field_chunk_metrics); if (lower_bound == MetricsDataType<T>() ||
if (minMax.first == MetricsDataType<T>() || upper_bound == MetricsDataType<T>()) {
minMax.second == MetricsDataType<T>()) {
return false; return false;
} }
bool should_skip = false; bool should_skip = false;
MetricsDataType<T> lower_bound = minMax.first;
MetricsDataType<T> upper_bound = minMax.second;
if (lower_inclusive && upper_inclusive) { if (lower_inclusive && upper_inclusive) {
should_skip = should_skip =
(lower_val > upper_bound) || (upper_val < lower_bound); (lower_val > upper_bound) || (upper_val < lower_bound);
@ -267,7 +269,7 @@ class SkipIndex {
return {minValue, maxValue, null_count}; return {minValue, maxValue, null_count};
} }
metricInfo<std::string_view> metricInfo<std::string>
ProcessStringFieldMetrics( ProcessStringFieldMetrics(
const milvus::VariableColumn<std::string>& var_column) { const milvus::VariableColumn<std::string>& var_column) {
int num_rows = var_column.NumRows(); int num_rows = var_column.NumRows();
@ -281,7 +283,7 @@ class SkipIndex {
break; break;
} }
if (start > num_rows - 1) { if (start > num_rows - 1) {
return {std::string_view(), std::string_view(), num_rows}; return {std::string(), std::string(), num_rows};
} }
std::string_view min_string = var_column.RawAt(start); std::string_view min_string = var_column.RawAt(start);
std::string_view max_string = var_column.RawAt(start); std::string_view max_string = var_column.RawAt(start);
@ -299,7 +301,8 @@ class SkipIndex {
max_string = val; max_string = val;
} }
} }
return {min_string, max_string, null_count}; // The field data may be released, so we need to copy the string to avoid invalid memory access.
return {std::string(min_string), std::string(max_string), null_count};
} }
private: private: