diff --git a/configs/advanced/query_node.yaml b/configs/advanced/query_node.yaml index 8ae2e7060d..1f9a773132 100644 --- a/configs/advanced/query_node.yaml +++ b/configs/advanced/query_node.yaml @@ -27,4 +27,5 @@ queryNode: recvBufSize: 64 segcore: - chunkSize: 32768 # 32M \ No newline at end of file + chunkSize: 32768 # 32M + simdType: auto # auto, avx512, avx2, sse diff --git a/internal/core/src/segcore/segcore_init_c.cpp b/internal/core/src/segcore/segcore_init_c.cpp index 3c2dacd337..af36fcb3df 100644 --- a/internal/core/src/segcore/segcore_init_c.cpp +++ b/internal/core/src/segcore/segcore_init_c.cpp @@ -42,3 +42,21 @@ SegcoreSetChunkSize(const int64_t value) { config.set_size_per_chunk(value); std::cout << "set config chunk_size: " << config.get_size_per_chunk() << std::endl; } + +extern "C" void +SegcoreSetSimdType(const char* value) { + milvus::engine::KnowhereConfig::SimdType simd_type; + if (strcmp(value, "auto") == 0) { + simd_type = milvus::engine::KnowhereConfig::SimdType::AUTO; + } else if (strcmp(value, "avx512") == 0) { + simd_type = milvus::engine::KnowhereConfig::SimdType::AVX512; + } else if (strcmp(value, "avx2") == 0) { + simd_type = milvus::engine::KnowhereConfig::SimdType::AVX2; + } else if (strcmp(value, "sse") == 0) { + simd_type = milvus::engine::KnowhereConfig::SimdType::SSE; + } else { + PanicInfo("invalid SIMD type: " + std::string(value)); + } + milvus::engine::KnowhereConfig::SetSimdType(simd_type); + std::cout << "set config simd_type: " << int(simd_type) << std::endl; +} diff --git a/internal/core/src/segcore/segcore_init_c.h b/internal/core/src/segcore/segcore_init_c.h index bd8bb656ba..985faefa8f 100644 --- a/internal/core/src/segcore/segcore_init_c.h +++ b/internal/core/src/segcore/segcore_init_c.h @@ -21,6 +21,9 @@ SegcoreInit(); void SegcoreSetChunkSize(const int64_t); +void +SegcoreSetSimdType(const char*); + #ifdef __cplusplus } #endif diff --git a/internal/core/unittest/test_init.cpp b/internal/core/unittest/test_init.cpp index 73ee67855c..e93cce6995 100644 --- a/internal/core/unittest/test_init.cpp +++ b/internal/core/unittest/test_init.cpp @@ -21,4 +21,5 @@ TEST(Init, Naive) { using namespace milvus::segcore; SegcoreInit(); SegcoreSetChunkSize(32768); -} \ No newline at end of file + SegcoreSetSimdType("auto"); +} diff --git a/internal/querynode/param_table.go b/internal/querynode/param_table.go index b09a3fabb7..6a444aa202 100644 --- a/internal/querynode/param_table.go +++ b/internal/querynode/param_table.go @@ -70,6 +70,7 @@ type ParamTable struct { // segcore ChunkSize int64 + SimdType string Log log.Config } @@ -115,6 +116,7 @@ func (p *ParamTable) Init() { p.initStatsChannelName() p.initSegcoreChunkSize() + p.initSegcoreSimdType() p.initLogCfg() }) @@ -263,6 +265,14 @@ func (p *ParamTable) initSegcoreChunkSize() { p.ChunkSize = p.ParseInt64("queryNode.segcore.chunkSize") } +func (p *ParamTable) initSegcoreSimdType() { + simdType, err := p.Load("queryNode.segcore.simdType") + if err != nil { + panic(err) + } + p.SimdType = simdType +} + func (p *ParamTable) initLogCfg() { p.Log = log.Config{} format, err := p.Load("log.format") diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 57c0bb565b..02cc25ffd9 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -29,6 +29,7 @@ import ( "errors" "strconv" "sync/atomic" + "unsafe" "go.uber.org/zap" @@ -104,6 +105,11 @@ func (node *QueryNode) InitSegcore() { // override segcore chunk size cChunkSize := C.int64_t(Params.ChunkSize) C.SegcoreSetChunkSize(cChunkSize) + + // override segcore SIMD type + cSimdType := C.CString(Params.SimdType) + C.SegcoreSetSimdType(cSimdType) + C.free(unsafe.Pointer(cSimdType)) } func (node *QueryNode) Init() error {