diff --git a/internal/storage/payload.go b/internal/storage/payload.go index 3a43d21802..b60bc686b8 100644 --- a/internal/storage/payload.go +++ b/internal/storage/payload.go @@ -21,8 +21,10 @@ package storage import "C" import ( "errors" + "fmt" "unsafe" + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/schemapb" ) @@ -168,14 +170,7 @@ func (w *PayloadWriter) AddBoolToPayload(msgs []bool) error { cLength := C.int(length) status := C.AddBooleanToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil + return HandleCStatus(&status, "AddBoolToPayload failed") } func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error { @@ -187,14 +182,7 @@ func (w *PayloadWriter) AddInt8ToPayload(msgs []int8) error { cLength := C.int(length) status := C.AddInt8ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil + return HandleCStatus(&status, "AddInt8ToPayload failed") } func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error { @@ -207,14 +195,7 @@ func (w *PayloadWriter) AddInt16ToPayload(msgs []int16) error { cLength := C.int(length) status := C.AddInt16ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil + return HandleCStatus(&status, "AddInt16ToPayload failed") } func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error { @@ -227,14 +208,7 @@ func (w *PayloadWriter) AddInt32ToPayload(msgs []int32) error { cLength := C.int(length) status := C.AddInt32ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil + return HandleCStatus(&status, "AddInt32ToPayload failed") } func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error { @@ -247,14 +221,7 @@ func (w *PayloadWriter) AddInt64ToPayload(msgs []int64) error { cLength := C.int(length) status := C.AddInt64ToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil + return HandleCStatus(&status, "AddInt64ToPayload failed") } func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error { @@ -267,14 +234,7 @@ func (w *PayloadWriter) AddFloatToPayload(msgs []float32) error { cLength := C.int(length) status := C.AddFloatToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil + return HandleCStatus(&status, "AddFloatToPayload failed") } func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error { @@ -287,14 +247,7 @@ func (w *PayloadWriter) AddDoubleToPayload(msgs []float64) error { cLength := C.int(length) status := C.AddDoubleToPayload(w.payloadWriterPtr, cMsgs, cLength) - - errCode := commonpb.ErrorCode(status.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New(msg) - } - return nil + return HandleCStatus(&status, "AddDoubleToPayload failed") } func (w *PayloadWriter) AddOneStringToPayload(msg string) error { @@ -307,15 +260,8 @@ func (w *PayloadWriter) AddOneStringToPayload(msg string) error { clength := C.int(length) defer C.free(unsafe.Pointer(cmsg)) - st := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength) - - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil + status := C.AddOneStringToPayload(w.payloadWriterPtr, cmsg, clength) + return HandleCStatus(&status, "AddOneStringToPayload failed") } // dimension > 0 && (%8 == 0) @@ -332,14 +278,8 @@ func (w *PayloadWriter) AddBinaryVectorToPayload(binVec []byte, dim int) error { cDim := C.int(dim) cLength := C.int(length / (dim / 8)) - st := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil + status := C.AddBinaryVectorToPayload(w.payloadWriterPtr, cBinVec, cDim, cLength) + return HandleCStatus(&status, "AddBinaryVectorToPayload failed") } // dimension > 0 && (%8 == 0) @@ -356,25 +296,13 @@ func (w *PayloadWriter) AddFloatVectorToPayload(floatVec []float32, dim int) err cDim := C.int(dim) cLength := C.int(length / dim) - st := C.AddFloatVectorToPayload(w.payloadWriterPtr, cVec, cDim, cLength) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil + status := C.AddFloatVectorToPayload(w.payloadWriterPtr, cVec, cDim, cLength) + return HandleCStatus(&status, "AddFloatVectorToPayload failed") } func (w *PayloadWriter) FinishPayloadWriter() error { - st := C.FinishPayloadWriter(w.payloadWriterPtr) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil + status := C.FinishPayloadWriter(w.payloadWriterPtr) + return HandleCStatus(&status, "FinishPayloadWriter failed") } func (w *PayloadWriter) GetPayloadBufferFromWriter() ([]byte, error) { @@ -395,14 +323,8 @@ func (w *PayloadWriter) GetPayloadLengthFromWriter() (int, error) { } func (w *PayloadWriter) ReleasePayloadWriter() error { - st := C.ReleasePayloadWriter(w.payloadWriterPtr) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil + status := C.ReleasePayloadWriter(w.payloadWriterPtr) + return HandleCStatus(&status, "ReleasePayloadWriter failed") } func (w *PayloadWriter) Close() error { @@ -472,14 +394,8 @@ func (r *PayloadReader) GetDataFromPayload(idx ...int) (interface{}, int, error) } func (r *PayloadReader) ReleasePayloadReader() error { - st := C.ReleasePayloadReader(r.payloadReaderPtr) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return errors.New(msg) - } - return nil + status := C.ReleasePayloadReader(r.payloadReaderPtr) + return HandleCStatus(&status, "ReleasePayloadReader failed") } func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) { @@ -490,12 +406,9 @@ func (r *PayloadReader) GetBoolFromPayload() ([]bool, error) { var cMsg *C.bool var cSize C.int - st := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) + status := C.GetBoolFromPayload(r.payloadReaderPtr, &cMsg, &cSize) + if err := HandleCStatus(&status, "GetBoolFromPayload failed"); err != nil { + return nil, err } slice := (*[1 << 28]bool)(unsafe.Pointer(cMsg))[:cSize:cSize] @@ -510,12 +423,9 @@ func (r *PayloadReader) GetInt8FromPayload() ([]int8, error) { var cMsg *C.int8_t var cSize C.int - st := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) + status := C.GetInt8FromPayload(r.payloadReaderPtr, &cMsg, &cSize) + if err := HandleCStatus(&status, "GetInt8FromPayload failed"); err != nil { + return nil, err } slice := (*[1 << 28]int8)(unsafe.Pointer(cMsg))[:cSize:cSize] @@ -530,12 +440,9 @@ func (r *PayloadReader) GetInt16FromPayload() ([]int16, error) { var cMsg *C.int16_t var cSize C.int - st := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) + status := C.GetInt16FromPayload(r.payloadReaderPtr, &cMsg, &cSize) + if err := HandleCStatus(&status, "GetInt16FromPayload failed"); err != nil { + return nil, err } slice := (*[1 << 28]int16)(unsafe.Pointer(cMsg))[:cSize:cSize] @@ -550,12 +457,9 @@ func (r *PayloadReader) GetInt32FromPayload() ([]int32, error) { var cMsg *C.int32_t var cSize C.int - st := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) + status := C.GetInt32FromPayload(r.payloadReaderPtr, &cMsg, &cSize) + if err := HandleCStatus(&status, "GetInt32FromPayload failed"); err != nil { + return nil, err } slice := (*[1 << 28]int32)(unsafe.Pointer(cMsg))[:cSize:cSize] @@ -570,12 +474,9 @@ func (r *PayloadReader) GetInt64FromPayload() ([]int64, error) { var cMsg *C.int64_t var cSize C.int - st := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) + status := C.GetInt64FromPayload(r.payloadReaderPtr, &cMsg, &cSize) + if err := HandleCStatus(&status, "GetInt64FromPayload failed"); err != nil { + return nil, err } slice := (*[1 << 28]int64)(unsafe.Pointer(cMsg))[:cSize:cSize] @@ -590,12 +491,9 @@ func (r *PayloadReader) GetFloatFromPayload() ([]float32, error) { var cMsg *C.float var cSize C.int - st := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) + status := C.GetFloatFromPayload(r.payloadReaderPtr, &cMsg, &cSize) + if err := HandleCStatus(&status, "GetFloatFromPayload failed"); err != nil { + return nil, err } slice := (*[1 << 28]float32)(unsafe.Pointer(cMsg))[:cSize:cSize] @@ -610,12 +508,9 @@ func (r *PayloadReader) GetDoubleFromPayload() ([]float64, error) { var cMsg *C.double var cSize C.int - st := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, errors.New(msg) + status := C.GetDoubleFromPayload(r.payloadReaderPtr, &cMsg, &cSize) + if err := HandleCStatus(&status, "GetDoubleFromPayload failed"); err != nil { + return nil, err } slice := (*[1 << 28]float64)(unsafe.Pointer(cMsg))[:cSize:cSize] @@ -630,13 +525,9 @@ func (r *PayloadReader) GetOneStringFromPayload(idx int) (string, error) { var cStr *C.char var cSize C.int - st := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize) - - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return "", errors.New(msg) + status := C.GetOneStringFromPayload(r.payloadReaderPtr, C.int(idx), &cStr, &cSize) + if err := HandleCStatus(&status, "GetOneStringFromPayload failed"); err != nil { + return "", err } return C.GoStringN(cStr, cSize), nil } @@ -651,12 +542,9 @@ func (r *PayloadReader) GetBinaryVectorFromPayload() ([]byte, int, error) { var cDim C.int var cLen C.int - st := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, 0, errors.New(msg) + status := C.GetBinaryVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen) + if err := HandleCStatus(&status, "GetBinaryVectorFromPayload failed"); err != nil { + return nil, 0, err } length := (cDim / 8) * cLen @@ -674,12 +562,9 @@ func (r *PayloadReader) GetFloatVectorFromPayload() ([]float32, int, error) { var cDim C.int var cLen C.int - st := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen) - errCode := commonpb.ErrorCode(st.error_code) - if errCode != commonpb.ErrorCode_Success { - msg := C.GoString(st.error_msg) - defer C.free(unsafe.Pointer(st.error_msg)) - return nil, 0, errors.New(msg) + status := C.GetFloatVectorFromPayload(r.payloadReaderPtr, &cMsg, &cDim, &cLen) + if err := HandleCStatus(&status, "GetFloatVectorFromPayload failed"); err != nil { + return nil, 0, err } length := cDim * cLen @@ -695,3 +580,22 @@ func (r *PayloadReader) GetPayloadLengthFromReader() (int, error) { func (r *PayloadReader) Close() error { return r.ReleasePayloadReader() } + +// HandleCStatus deal with the error returned from CGO +func HandleCStatus(status *C.CStatus, extraInfo string) error { + if status.error_code == 0 { + return nil + } + errorCode := status.error_code + errorName, ok := commonpb.ErrorCode_name[int32(errorCode)] + if !ok { + errorName = "UnknownError" + } + errorMsg := C.GoString(status.error_msg) + defer C.free(unsafe.Pointer(status.error_msg)) + + finalMsg := fmt.Sprintf("[%s] %s", errorName, errorMsg) + logMsg := fmt.Sprintf("%s, C Runtime Exception: %s\n", extraInfo, finalMsg) + log.Warn(logMsg) + return errors.New(finalMsg) +}