mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-30 07:25:37 +08:00
optimize similarity template (#2227)
* optimize similarity template Signed-off-by: yudong.cai <yudong.cai@zilliz.com> * code opt Signed-off-by: yudong.cai <yudong.cai@zilliz.com>
This commit is contained in:
parent
ac8e9ff020
commit
f218217064
@ -313,70 +313,18 @@ struct SimilarityL2<1> {
|
||||
|
||||
/* Same as SimilarityL2<1>; this specialization exists only so the build passes */
|
||||
template<>
|
||||
struct SimilarityL2<8> {
|
||||
struct SimilarityL2<8> : SimilarityL2<1> {
|
||||
static constexpr int simdwidth = 1;
|
||||
static constexpr MetricType metric_type = METRIC_L2;
|
||||
|
||||
const float *y, *yi;
|
||||
|
||||
explicit SimilarityL2 (const float * y): y(y) {}
|
||||
|
||||
/******* scalar accumulator *******/
|
||||
|
||||
float accu;
|
||||
|
||||
void begin () {
|
||||
accu = 0;
|
||||
yi = y;
|
||||
}
|
||||
|
||||
void add_component (float x) {
|
||||
float tmp = *yi++ - x;
|
||||
accu += tmp * tmp;
|
||||
}
|
||||
|
||||
void add_component_2 (float x1, float x2) {
|
||||
float tmp = x1 - x2;
|
||||
accu += tmp * tmp;
|
||||
}
|
||||
|
||||
float result () {
|
||||
return accu;
|
||||
}
|
||||
explicit SimilarityL2 (const float * y) : SimilarityL2<1>(y) {}
|
||||
};
|
||||
|
||||
/* Same as SimilarityL2<1>; this specialization exists only so the build passes */
|
||||
template<>
|
||||
struct SimilarityL2<16> {
|
||||
struct SimilarityL2<16> : SimilarityL2<1> {
|
||||
static constexpr int simdwidth = 1;
|
||||
static constexpr MetricType metric_type = METRIC_L2;
|
||||
|
||||
const float *y, *yi;
|
||||
|
||||
explicit SimilarityL2 (const float * y): y(y) {}
|
||||
|
||||
/******* scalar accumulator *******/
|
||||
|
||||
float accu;
|
||||
|
||||
void begin () {
|
||||
accu = 0;
|
||||
yi = y;
|
||||
}
|
||||
|
||||
void add_component (float x) {
|
||||
float tmp = *yi++ - x;
|
||||
accu += tmp * tmp;
|
||||
}
|
||||
|
||||
void add_component_2 (float x1, float x2) {
|
||||
float tmp = x1 - x2;
|
||||
accu += tmp * tmp;
|
||||
}
|
||||
|
||||
float result () {
|
||||
return accu;
|
||||
}
|
||||
explicit SimilarityL2 (const float * y) : SimilarityL2<1>(y) {}
|
||||
};
|
||||
|
||||
|
||||
@ -414,62 +362,18 @@ struct SimilarityIP<1> {
|
||||
|
||||
/* Same as SimilarityIP<1>; this specialization exists only so the build passes */
|
||||
template<>
|
||||
struct SimilarityIP<8> {
|
||||
struct SimilarityIP<8> : SimilarityIP<1> {
|
||||
static constexpr int simdwidth = 1;
|
||||
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
||||
const float *y, *yi;
|
||||
|
||||
float accu;
|
||||
|
||||
explicit SimilarityIP (const float * y):
|
||||
y (y) {}
|
||||
|
||||
void begin () {
|
||||
accu = 0;
|
||||
yi = y;
|
||||
}
|
||||
|
||||
void add_component (float x) {
|
||||
accu += *yi++ * x;
|
||||
}
|
||||
|
||||
void add_component_2 (float x1, float x2) {
|
||||
accu += x1 * x2;
|
||||
}
|
||||
|
||||
float result () {
|
||||
return accu;
|
||||
}
|
||||
explicit SimilarityIP (const float * y) : SimilarityIP<1>(y) {}
|
||||
};
|
||||
|
||||
/* Same as SimilarityIP<1>; this specialization exists only so the build passes */
|
||||
template<>
|
||||
struct SimilarityIP<16> {
|
||||
struct SimilarityIP<16> : SimilarityIP<1> {
|
||||
static constexpr int simdwidth = 1;
|
||||
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
||||
const float *y, *yi;
|
||||
|
||||
float accu;
|
||||
|
||||
explicit SimilarityIP (const float * y):
|
||||
y (y) {}
|
||||
|
||||
void begin () {
|
||||
accu = 0;
|
||||
yi = y;
|
||||
}
|
||||
|
||||
void add_component (float x) {
|
||||
accu += *yi++ * x;
|
||||
}
|
||||
|
||||
void add_component_2 (float x1, float x2) {
|
||||
accu += x1 * x2;
|
||||
}
|
||||
|
||||
float result () {
|
||||
return accu;
|
||||
}
|
||||
explicit SimilarityIP (const float * y) : SimilarityIP<1>(y) {}
|
||||
};
|
||||
|
||||
|
||||
|
||||
@ -482,40 +482,10 @@ struct SimilarityL2_avx<8> {
|
||||
|
||||
/* Same as SimilarityL2_avx<8>; this specialization exists only so the build passes */
|
||||
template<>
|
||||
struct SimilarityL2_avx<16> {
|
||||
struct SimilarityL2_avx<16> : SimilarityL2_avx<8>{
|
||||
static constexpr int simdwidth = 8;
|
||||
static constexpr MetricType metric_type = METRIC_L2;
|
||||
|
||||
const float *y, *yi;
|
||||
|
||||
explicit SimilarityL2_avx (const float * y): y(y) {}
|
||||
__m256 accu8;
|
||||
|
||||
void begin_8 () {
|
||||
accu8 = _mm256_setzero_ps();
|
||||
yi = y;
|
||||
}
|
||||
|
||||
void add_8_components (__m256 x) {
|
||||
__m256 yiv = _mm256_loadu_ps (yi);
|
||||
yi += 8;
|
||||
__m256 tmp = yiv - x;
|
||||
accu8 += tmp * tmp;
|
||||
}
|
||||
|
||||
void add_8_components_2 (__m256 x, __m256 y) {
|
||||
__m256 tmp = y - x;
|
||||
accu8 += tmp * tmp;
|
||||
}
|
||||
|
||||
float result_8 () {
|
||||
__m256 sum = _mm256_hadd_ps(accu8, accu8);
|
||||
__m256 sum2 = _mm256_hadd_ps(sum, sum);
|
||||
// now add the 0th and 4th component
|
||||
return
|
||||
_mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
|
||||
_mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
|
||||
}
|
||||
explicit SimilarityL2_avx (const float * y) : SimilarityL2_avx<8>(y) {}
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -596,42 +566,10 @@ struct SimilarityIP_avx<8> {
|
||||
|
||||
/* Same as SimilarityIP_avx<8>; this specialization exists only so the build passes */
|
||||
template<>
|
||||
struct SimilarityIP_avx<16> {
|
||||
struct SimilarityIP_avx<16> : SimilarityIP_avx<8> {
|
||||
static constexpr int simdwidth = 8;
|
||||
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
|
||||
|
||||
const float *y, *yi;
|
||||
|
||||
float accu;
|
||||
|
||||
explicit SimilarityIP_avx (const float * y):
|
||||
y (y) {}
|
||||
|
||||
__m256 accu8;
|
||||
|
||||
void begin_8 () {
|
||||
accu8 = _mm256_setzero_ps();
|
||||
yi = y;
|
||||
}
|
||||
|
||||
void add_8_components (__m256 x) {
|
||||
__m256 yiv = _mm256_loadu_ps (yi);
|
||||
yi += 8;
|
||||
accu8 += yiv * x;
|
||||
}
|
||||
|
||||
void add_8_components_2 (__m256 x1, __m256 x2) {
|
||||
accu8 += x1 * x2;
|
||||
}
|
||||
|
||||
float result_8 () {
|
||||
__m256 sum = _mm256_hadd_ps(accu8, accu8);
|
||||
__m256 sum2 = _mm256_hadd_ps(sum, sum);
|
||||
// now add the 0th and 4th component
|
||||
return
|
||||
_mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
|
||||
_mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
|
||||
}
|
||||
explicit SimilarityIP_avx (const float * y) : SimilarityIP_avx<8>(y) {}
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user