mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 09:38:39 +08:00
Get SIMD type used in faiss (#8849)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
546499ff63
commit
a10f421c14
@ -14,6 +14,7 @@
|
|||||||
#include "knowhere/archive/KnowhereConfig.h"
|
#include "knowhere/archive/KnowhereConfig.h"
|
||||||
#include "easyloggingpp/easylogging++.h"
|
#include "easyloggingpp/easylogging++.h"
|
||||||
#include "ConfigKnowhere.h"
|
#include "ConfigKnowhere.h"
|
||||||
|
#include "faiss/FaissHook.h"
|
||||||
|
|
||||||
namespace milvus {
|
namespace milvus {
|
||||||
namespace config {
|
namespace config {
|
||||||
@ -36,7 +37,7 @@ KnowhereInitImpl() {
|
|||||||
std::call_once(init_knowhere_once_, init);
|
std::call_once(init_knowhere_once_, init);
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
std::string
|
||||||
KnowhereSetSimdType(const char* value) {
|
KnowhereSetSimdType(const char* value) {
|
||||||
milvus::engine::KnowhereConfig::SimdType simd_type;
|
milvus::engine::KnowhereConfig::SimdType simd_type;
|
||||||
if (strcmp(value, "auto") == 0) {
|
if (strcmp(value, "auto") == 0) {
|
||||||
@ -50,7 +51,7 @@ KnowhereSetSimdType(const char* value) {
|
|||||||
} else {
|
} else {
|
||||||
PanicInfo("invalid SIMD type: " + std::string(value));
|
PanicInfo("invalid SIMD type: " + std::string(value));
|
||||||
}
|
}
|
||||||
milvus::engine::KnowhereConfig::SetSimdType(simd_type);
|
return milvus::engine::KnowhereConfig::SetSimdType(simd_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace config
|
} // namespace config
|
||||||
|
|||||||
@ -17,7 +17,7 @@ namespace milvus::config {
|
|||||||
void
|
void
|
||||||
KnowhereInitImpl();
|
KnowhereInitImpl();
|
||||||
|
|
||||||
void
|
std::string
|
||||||
KnowhereSetSimdType(const char*);
|
KnowhereSetSimdType(const char*);
|
||||||
|
|
||||||
} // namespace milvus::config
|
} // namespace milvus::config
|
||||||
|
|||||||
@ -27,6 +27,7 @@
|
|||||||
#include "utils/ConfigUtils.h"
|
#include "utils/ConfigUtils.h"
|
||||||
#include "utils/Error.h"
|
#include "utils/Error.h"
|
||||||
#include "utils/Log.h"
|
#include "utils/Log.h"
|
||||||
|
#include "index/knowhere/knowhere/common/Exception.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -36,7 +37,7 @@ namespace engine {
|
|||||||
|
|
||||||
constexpr int64_t M_BYTE = 1024 * 1024;
|
constexpr int64_t M_BYTE = 1024 * 1024;
|
||||||
|
|
||||||
Status
|
std::string
|
||||||
KnowhereConfig::SetSimdType(const SimdType simd_type) {
|
KnowhereConfig::SetSimdType(const SimdType simd_type) {
|
||||||
if (simd_type == SimdType::AVX512) {
|
if (simd_type == SimdType::AVX512) {
|
||||||
faiss::faiss_use_avx512 = true;
|
faiss::faiss_use_avx512 = true;
|
||||||
@ -58,12 +59,11 @@ KnowhereConfig::SetSimdType(const SimdType simd_type) {
|
|||||||
|
|
||||||
std::string cpu_flag;
|
std::string cpu_flag;
|
||||||
if (faiss::hook_init(cpu_flag)) {
|
if (faiss::hook_init(cpu_flag)) {
|
||||||
std::cout << "FAISS hook " << cpu_flag << std::endl;
|
|
||||||
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
|
LOG_KNOWHERE_DEBUG_ << "FAISS hook " << cpu_flag;
|
||||||
return Status::OK();
|
return cpu_flag;
|
||||||
}
|
}
|
||||||
|
|
||||||
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
|
KNOWHERE_THROW_MSG("FAISS hook fail, CPU not supported!");
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void
|
||||||
|
|||||||
@ -12,6 +12,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "utils/Status.h"
|
#include "utils/Status.h"
|
||||||
|
|
||||||
@ -30,7 +31,7 @@ class KnowhereConfig {
|
|||||||
AVX512, // only enable AVX512
|
AVX512, // only enable AVX512
|
||||||
};
|
};
|
||||||
|
|
||||||
static Status
|
static std::string
|
||||||
SetSimdType(const SimdType simd_type);
|
SetSimdType(const SimdType simd_type);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -31,9 +31,9 @@ class KnowhereException : public std::exception {
|
|||||||
|
|
||||||
#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
|
#define KNOHWERE_ERROR_MSG(MSG) printf("%s", KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__).what())
|
||||||
|
|
||||||
#define KNOWHERE_THROW_MSG(MSG) \
|
#define KNOWHERE_THROW_MSG(MSG) \
|
||||||
do { \
|
do { \
|
||||||
throw KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
|
throw milvus::knowhere::KnowhereException(MSG, __PRETTY_FUNCTION__, __FILE__, __LINE__); \
|
||||||
} while (false)
|
} while (false)
|
||||||
|
|
||||||
#define KNOHERE_THROW_FORMAT(FMT, ...) \
|
#define KNOHERE_THROW_FORMAT(FMT, ...) \
|
||||||
|
|||||||
@ -9,6 +9,7 @@
|
|||||||
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
||||||
// or implied. See the License for the specific language governing permissions and limitations under the License
|
// or implied. See the License for the specific language governing permissions and limitations under the License
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
#include "config/ConfigKnowhere.h"
|
#include "config/ConfigKnowhere.h"
|
||||||
#include "indexbuilder/init_c.h"
|
#include "indexbuilder/init_c.h"
|
||||||
|
|
||||||
@ -17,7 +18,12 @@ IndexBuilderInit() {
|
|||||||
milvus::config::KnowhereInitImpl();
|
milvus::config::KnowhereInitImpl();
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
// return value must be freed by the caller
|
||||||
|
char*
|
||||||
IndexBuilderSetSimdType(const char* value) {
|
IndexBuilderSetSimdType(const char* value) {
|
||||||
milvus::config::KnowhereSetSimdType(value);
|
auto real_type = milvus::config::KnowhereSetSimdType(value);
|
||||||
|
char* ret = reinterpret_cast<char*>(malloc(real_type.length() + 1));
|
||||||
|
memcpy(ret, real_type.c_str(), real_type.length());
|
||||||
|
ret[real_type.length()] = 0;
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,7 +18,8 @@ extern "C" {
|
|||||||
void
|
void
|
||||||
IndexBuilderInit();
|
IndexBuilderInit();
|
||||||
|
|
||||||
void
|
// return value must be freed by the caller
|
||||||
|
char*
|
||||||
IndexBuilderSetSimdType(const char*);
|
IndexBuilderSetSimdType(const char*);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|||||||
@ -29,10 +29,15 @@ SegcoreSetChunkRows(const int64_t value) {
|
|||||||
LOG_SEGCORE_DEBUG_ << "set config chunk_size: " << config.get_chunk_rows();
|
LOG_SEGCORE_DEBUG_ << "set config chunk_size: " << config.get_chunk_rows();
|
||||||
}
|
}
|
||||||
|
|
||||||
extern "C" void
|
// return value must be freed by the caller
|
||||||
|
extern "C" char*
|
||||||
SegcoreSetSimdType(const char* value) {
|
SegcoreSetSimdType(const char* value) {
|
||||||
milvus::config::KnowhereSetSimdType(value);
|
|
||||||
LOG_SEGCORE_DEBUG_ << "set config simd_type: " << value;
|
LOG_SEGCORE_DEBUG_ << "set config simd_type: " << value;
|
||||||
|
auto real_type = milvus::config::KnowhereSetSimdType(value);
|
||||||
|
char* ret = reinterpret_cast<char*>(malloc(real_type.length() + 1));
|
||||||
|
memcpy(ret, real_type.c_str(), real_type.length());
|
||||||
|
ret[real_type.length()] = 0;
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace milvus::segcore
|
} // namespace milvus::segcore
|
||||||
|
|||||||
@ -21,7 +21,8 @@ SegcoreInit();
|
|||||||
void
|
void
|
||||||
SegcoreSetChunkRows(const int64_t);
|
SegcoreSetChunkRows(const int64_t);
|
||||||
|
|
||||||
void
|
// return value must be freed by the caller
|
||||||
|
char*
|
||||||
SegcoreSetSimdType(const char*);
|
SegcoreSetSimdType(const char*);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|||||||
@ -114,9 +114,11 @@ func (i *IndexNode) Register() error {
|
|||||||
func (i *IndexNode) initKnowhere() {
|
func (i *IndexNode) initKnowhere() {
|
||||||
C.IndexBuilderInit()
|
C.IndexBuilderInit()
|
||||||
|
|
||||||
// override segcore SIMD type
|
// override index builder SIMD type
|
||||||
cSimdType := C.CString(Params.SimdType)
|
cSimdType := C.CString(Params.SimdType)
|
||||||
C.IndexBuilderSetSimdType(cSimdType)
|
cRealSimdType := C.IndexBuilderSetSimdType(cSimdType)
|
||||||
|
Params.SimdType = C.GoString(cRealSimdType)
|
||||||
|
C.free(unsafe.Pointer(cRealSimdType))
|
||||||
C.free(unsafe.Pointer(cSimdType))
|
C.free(unsafe.Pointer(cSimdType))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -124,7 +124,9 @@ func (node *QueryNode) InitSegcore() {
|
|||||||
|
|
||||||
// override segcore SIMD type
|
// override segcore SIMD type
|
||||||
cSimdType := C.CString(Params.SimdType)
|
cSimdType := C.CString(Params.SimdType)
|
||||||
C.SegcoreSetSimdType(cSimdType)
|
cRealSimdType := C.SegcoreSetSimdType(cSimdType)
|
||||||
|
Params.SimdType = C.GoString(cRealSimdType)
|
||||||
|
C.free(unsafe.Pointer(cRealSimdType))
|
||||||
C.free(unsafe.Pointer(cSimdType))
|
C.free(unsafe.Pointer(cSimdType))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user