enhance: move segcore codes of segment into one package (#37722)

issue: #33285

- move most cgo opeartions related to search/query into segcore package
for reusing for streamingnode.
- add go unittest for segcore operations.

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2024-11-29 10:22:36 +08:00 committed by GitHub
parent 843c1f506f
commit c6dcef7b84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
55 changed files with 2399 additions and 859 deletions

View File

@ -147,6 +147,8 @@ issues:
- path: .+_test\.go - path: .+_test\.go
linters: linters:
- forbidigo - forbidigo
- path: mocks\/(.)+mock_(.+)\.go
text: "don't use an underscore in package name"
exclude: exclude:
- should have a package comment - should have a package comment
- should have comment - should have comment

View File

@ -473,7 +473,6 @@ generate-mockery-querycoord: getdeps
generate-mockery-querynode: getdeps build-cpp generate-mockery-querynode: getdeps build-cpp
@source $(PWD)/scripts/setenv.sh # setup PKG_CONFIG_PATH @source $(PWD)/scripts/setenv.sh # setup PKG_CONFIG_PATH
$(INSTALL_PATH)/mockery --name=QueryHook --dir=$(PWD)/internal/querynodev2/optimizers --output=$(PWD)/internal/querynodev2/optimizers --filename=mock_query_hook.go --with-expecter --outpkg=optimizers --structname=MockQueryHook --inpackage
$(INSTALL_PATH)/mockery --name=Manager --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_manager.go --with-expecter --outpkg=cluster --structname=MockManager --inpackage $(INSTALL_PATH)/mockery --name=Manager --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_manager.go --with-expecter --outpkg=cluster --structname=MockManager --inpackage
$(INSTALL_PATH)/mockery --name=SegmentManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_segment_manager.go --with-expecter --outpkg=segments --structname=MockSegmentManager --inpackage $(INSTALL_PATH)/mockery --name=SegmentManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_segment_manager.go --with-expecter --outpkg=segments --structname=MockSegmentManager --inpackage
$(INSTALL_PATH)/mockery --name=CollectionManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_collection_manager.go --with-expecter --outpkg=segments --structname=MockCollectionManager --inpackage $(INSTALL_PATH)/mockery --name=CollectionManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_collection_manager.go --with-expecter --outpkg=segments --structname=MockCollectionManager --inpackage

View File

@ -61,6 +61,9 @@ packages:
interfaces: interfaces:
StreamingCoordCataLog: StreamingCoordCataLog:
StreamingNodeCataLog: StreamingNodeCataLog:
github.com/milvus-io/milvus/internal/util/segcore:
interfaces:
CSegment:
github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer: github.com/milvus-io/milvus/internal/util/streamingutil/service/discoverer:
interfaces: interfaces:
Discoverer: Discoverer:
@ -72,6 +75,9 @@ packages:
interfaces: interfaces:
Resolver: Resolver:
Builder: Builder:
github.com/milvus-io/milvus/internal/util/searchutil/optimizers:
interfaces:
QueryHook:
google.golang.org/grpc/resolver: google.golang.org/grpc/resolver:
interfaces: interfaces:
ClientConn: ClientConn:

View File

@ -11,6 +11,9 @@
#pragma once #pragma once
#include <stdbool.h>
#include <stdint.h>
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif

View File

@ -0,0 +1,708 @@
// Code generated by mockery v2.46.0. DO NOT EDIT.
package mock_segcore
import (
context "context"
segcore "github.com/milvus-io/milvus/internal/util/segcore"
mock "github.com/stretchr/testify/mock"
)
// MockCSegment is an autogenerated mock type for the CSegment type
type MockCSegment struct {
mock.Mock
}
type MockCSegment_Expecter struct {
mock *mock.Mock
}
func (_m *MockCSegment) EXPECT() *MockCSegment_Expecter {
return &MockCSegment_Expecter{mock: &_m.Mock}
}
// AddFieldDataInfo provides a mock function with given fields: ctx, request
func (_m *MockCSegment) AddFieldDataInfo(ctx context.Context, request *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error) {
ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for AddFieldDataInfo")
}
var r0 *segcore.AddFieldDataInfoResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error)); ok {
return rf(ctx, request)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) *segcore.AddFieldDataInfoResult); ok {
r0 = rf(ctx, request)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcore.AddFieldDataInfoResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.LoadFieldDataRequest) error); ok {
r1 = rf(ctx, request)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCSegment_AddFieldDataInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddFieldDataInfo'
type MockCSegment_AddFieldDataInfo_Call struct {
*mock.Call
}
// AddFieldDataInfo is a helper method to define mock.On call
// - ctx context.Context
// - request *segcore.LoadFieldDataRequest
func (_e *MockCSegment_Expecter) AddFieldDataInfo(ctx interface{}, request interface{}) *MockCSegment_AddFieldDataInfo_Call {
return &MockCSegment_AddFieldDataInfo_Call{Call: _e.mock.On("AddFieldDataInfo", ctx, request)}
}
func (_c *MockCSegment_AddFieldDataInfo_Call) Run(run func(ctx context.Context, request *segcore.LoadFieldDataRequest)) *MockCSegment_AddFieldDataInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.LoadFieldDataRequest))
})
return _c
}
func (_c *MockCSegment_AddFieldDataInfo_Call) Return(_a0 *segcore.AddFieldDataInfoResult, _a1 error) *MockCSegment_AddFieldDataInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCSegment_AddFieldDataInfo_Call) RunAndReturn(run func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.AddFieldDataInfoResult, error)) *MockCSegment_AddFieldDataInfo_Call {
_c.Call.Return(run)
return _c
}
// Delete provides a mock function with given fields: ctx, request
func (_m *MockCSegment) Delete(ctx context.Context, request *segcore.DeleteRequest) (*segcore.DeleteResult, error) {
ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for Delete")
}
var r0 *segcore.DeleteResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.DeleteRequest) (*segcore.DeleteResult, error)); ok {
return rf(ctx, request)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.DeleteRequest) *segcore.DeleteResult); ok {
r0 = rf(ctx, request)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcore.DeleteResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.DeleteRequest) error); ok {
r1 = rf(ctx, request)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCSegment_Delete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Delete'
type MockCSegment_Delete_Call struct {
*mock.Call
}
// Delete is a helper method to define mock.On call
// - ctx context.Context
// - request *segcore.DeleteRequest
func (_e *MockCSegment_Expecter) Delete(ctx interface{}, request interface{}) *MockCSegment_Delete_Call {
return &MockCSegment_Delete_Call{Call: _e.mock.On("Delete", ctx, request)}
}
func (_c *MockCSegment_Delete_Call) Run(run func(ctx context.Context, request *segcore.DeleteRequest)) *MockCSegment_Delete_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.DeleteRequest))
})
return _c
}
func (_c *MockCSegment_Delete_Call) Return(_a0 *segcore.DeleteResult, _a1 error) *MockCSegment_Delete_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCSegment_Delete_Call) RunAndReturn(run func(context.Context, *segcore.DeleteRequest) (*segcore.DeleteResult, error)) *MockCSegment_Delete_Call {
_c.Call.Return(run)
return _c
}
// HasRawData provides a mock function with given fields: fieldID
func (_m *MockCSegment) HasRawData(fieldID int64) bool {
ret := _m.Called(fieldID)
if len(ret) == 0 {
panic("no return value specified for HasRawData")
}
var r0 bool
if rf, ok := ret.Get(0).(func(int64) bool); ok {
r0 = rf(fieldID)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockCSegment_HasRawData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HasRawData'
type MockCSegment_HasRawData_Call struct {
*mock.Call
}
// HasRawData is a helper method to define mock.On call
// - fieldID int64
func (_e *MockCSegment_Expecter) HasRawData(fieldID interface{}) *MockCSegment_HasRawData_Call {
return &MockCSegment_HasRawData_Call{Call: _e.mock.On("HasRawData", fieldID)}
}
func (_c *MockCSegment_HasRawData_Call) Run(run func(fieldID int64)) *MockCSegment_HasRawData_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64))
})
return _c
}
func (_c *MockCSegment_HasRawData_Call) Return(_a0 bool) *MockCSegment_HasRawData_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCSegment_HasRawData_Call) RunAndReturn(run func(int64) bool) *MockCSegment_HasRawData_Call {
_c.Call.Return(run)
return _c
}
// ID provides a mock function with given fields:
func (_m *MockCSegment) ID() int64 {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ID")
}
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockCSegment_ID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ID'
type MockCSegment_ID_Call struct {
*mock.Call
}
// ID is a helper method to define mock.On call
func (_e *MockCSegment_Expecter) ID() *MockCSegment_ID_Call {
return &MockCSegment_ID_Call{Call: _e.mock.On("ID")}
}
func (_c *MockCSegment_ID_Call) Run(run func()) *MockCSegment_ID_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCSegment_ID_Call) Return(_a0 int64) *MockCSegment_ID_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCSegment_ID_Call) RunAndReturn(run func() int64) *MockCSegment_ID_Call {
_c.Call.Return(run)
return _c
}
// Insert provides a mock function with given fields: ctx, request
func (_m *MockCSegment) Insert(ctx context.Context, request *segcore.InsertRequest) (*segcore.InsertResult, error) {
ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for Insert")
}
var r0 *segcore.InsertResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.InsertRequest) (*segcore.InsertResult, error)); ok {
return rf(ctx, request)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.InsertRequest) *segcore.InsertResult); ok {
r0 = rf(ctx, request)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcore.InsertResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.InsertRequest) error); ok {
r1 = rf(ctx, request)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCSegment_Insert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Insert'
type MockCSegment_Insert_Call struct {
*mock.Call
}
// Insert is a helper method to define mock.On call
// - ctx context.Context
// - request *segcore.InsertRequest
func (_e *MockCSegment_Expecter) Insert(ctx interface{}, request interface{}) *MockCSegment_Insert_Call {
return &MockCSegment_Insert_Call{Call: _e.mock.On("Insert", ctx, request)}
}
func (_c *MockCSegment_Insert_Call) Run(run func(ctx context.Context, request *segcore.InsertRequest)) *MockCSegment_Insert_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.InsertRequest))
})
return _c
}
func (_c *MockCSegment_Insert_Call) Return(_a0 *segcore.InsertResult, _a1 error) *MockCSegment_Insert_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCSegment_Insert_Call) RunAndReturn(run func(context.Context, *segcore.InsertRequest) (*segcore.InsertResult, error)) *MockCSegment_Insert_Call {
_c.Call.Return(run)
return _c
}
// LoadFieldData provides a mock function with given fields: ctx, request
func (_m *MockCSegment) LoadFieldData(ctx context.Context, request *segcore.LoadFieldDataRequest) (*segcore.LoadFieldDataResult, error) {
ret := _m.Called(ctx, request)
if len(ret) == 0 {
panic("no return value specified for LoadFieldData")
}
var r0 *segcore.LoadFieldDataResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.LoadFieldDataResult, error)); ok {
return rf(ctx, request)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.LoadFieldDataRequest) *segcore.LoadFieldDataResult); ok {
r0 = rf(ctx, request)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcore.LoadFieldDataResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.LoadFieldDataRequest) error); ok {
r1 = rf(ctx, request)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCSegment_LoadFieldData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadFieldData'
type MockCSegment_LoadFieldData_Call struct {
*mock.Call
}
// LoadFieldData is a helper method to define mock.On call
// - ctx context.Context
// - request *segcore.LoadFieldDataRequest
func (_e *MockCSegment_Expecter) LoadFieldData(ctx interface{}, request interface{}) *MockCSegment_LoadFieldData_Call {
return &MockCSegment_LoadFieldData_Call{Call: _e.mock.On("LoadFieldData", ctx, request)}
}
func (_c *MockCSegment_LoadFieldData_Call) Run(run func(ctx context.Context, request *segcore.LoadFieldDataRequest)) *MockCSegment_LoadFieldData_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.LoadFieldDataRequest))
})
return _c
}
func (_c *MockCSegment_LoadFieldData_Call) Return(_a0 *segcore.LoadFieldDataResult, _a1 error) *MockCSegment_LoadFieldData_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCSegment_LoadFieldData_Call) RunAndReturn(run func(context.Context, *segcore.LoadFieldDataRequest) (*segcore.LoadFieldDataResult, error)) *MockCSegment_LoadFieldData_Call {
_c.Call.Return(run)
return _c
}
// MemSize provides a mock function with given fields:
func (_m *MockCSegment) MemSize() int64 {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for MemSize")
}
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockCSegment_MemSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MemSize'
type MockCSegment_MemSize_Call struct {
*mock.Call
}
// MemSize is a helper method to define mock.On call
func (_e *MockCSegment_Expecter) MemSize() *MockCSegment_MemSize_Call {
return &MockCSegment_MemSize_Call{Call: _e.mock.On("MemSize")}
}
func (_c *MockCSegment_MemSize_Call) Run(run func()) *MockCSegment_MemSize_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCSegment_MemSize_Call) Return(_a0 int64) *MockCSegment_MemSize_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCSegment_MemSize_Call) RunAndReturn(run func() int64) *MockCSegment_MemSize_Call {
_c.Call.Return(run)
return _c
}
// RawPointer provides a mock function with given fields:
func (_m *MockCSegment) RawPointer() segcore.CSegmentInterface {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for RawPointer")
}
var r0 segcore.CSegmentInterface
if rf, ok := ret.Get(0).(func() segcore.CSegmentInterface); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(segcore.CSegmentInterface)
}
return r0
}
// MockCSegment_RawPointer_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RawPointer'
type MockCSegment_RawPointer_Call struct {
*mock.Call
}
// RawPointer is a helper method to define mock.On call
func (_e *MockCSegment_Expecter) RawPointer() *MockCSegment_RawPointer_Call {
return &MockCSegment_RawPointer_Call{Call: _e.mock.On("RawPointer")}
}
func (_c *MockCSegment_RawPointer_Call) Run(run func()) *MockCSegment_RawPointer_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCSegment_RawPointer_Call) Return(_a0 segcore.CSegmentInterface) *MockCSegment_RawPointer_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCSegment_RawPointer_Call) RunAndReturn(run func() segcore.CSegmentInterface) *MockCSegment_RawPointer_Call {
_c.Call.Return(run)
return _c
}
// Release provides a mock function with given fields:
func (_m *MockCSegment) Release() {
_m.Called()
}
// MockCSegment_Release_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Release'
type MockCSegment_Release_Call struct {
*mock.Call
}
// Release is a helper method to define mock.On call
func (_e *MockCSegment_Expecter) Release() *MockCSegment_Release_Call {
return &MockCSegment_Release_Call{Call: _e.mock.On("Release")}
}
func (_c *MockCSegment_Release_Call) Run(run func()) *MockCSegment_Release_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCSegment_Release_Call) Return() *MockCSegment_Release_Call {
_c.Call.Return()
return _c
}
func (_c *MockCSegment_Release_Call) RunAndReturn(run func()) *MockCSegment_Release_Call {
_c.Call.Return(run)
return _c
}
// Retrieve provides a mock function with given fields: ctx, plan
func (_m *MockCSegment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcore.RetrieveResult, error) {
ret := _m.Called(ctx, plan)
if len(ret) == 0 {
panic("no return value specified for Retrieve")
}
var r0 *segcore.RetrieveResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) (*segcore.RetrieveResult, error)); ok {
return rf(ctx, plan)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) *segcore.RetrieveResult); ok {
r0 = rf(ctx, plan)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcore.RetrieveResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlan) error); ok {
r1 = rf(ctx, plan)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCSegment_Retrieve_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Retrieve'
type MockCSegment_Retrieve_Call struct {
*mock.Call
}
// Retrieve is a helper method to define mock.On call
// - ctx context.Context
// - plan *segcore.RetrievePlan
func (_e *MockCSegment_Expecter) Retrieve(ctx interface{}, plan interface{}) *MockCSegment_Retrieve_Call {
return &MockCSegment_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, plan)}
}
func (_c *MockCSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlan)) *MockCSegment_Retrieve_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.RetrievePlan))
})
return _c
}
func (_c *MockCSegment_Retrieve_Call) Return(_a0 *segcore.RetrieveResult, _a1 error) *MockCSegment_Retrieve_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlan) (*segcore.RetrieveResult, error)) *MockCSegment_Retrieve_Call {
_c.Call.Return(run)
return _c
}
// RetrieveByOffsets provides a mock function with given fields: ctx, plan
func (_m *MockCSegment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcore.RetrieveResult, error) {
ret := _m.Called(ctx, plan)
if len(ret) == 0 {
panic("no return value specified for RetrieveByOffsets")
}
var r0 *segcore.RetrieveResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcore.RetrieveResult, error)); ok {
return rf(ctx, plan)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) *segcore.RetrieveResult); ok {
r0 = rf(ctx, plan)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcore.RetrieveResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlanWithOffsets) error); ok {
r1 = rf(ctx, plan)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCSegment_RetrieveByOffsets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveByOffsets'
type MockCSegment_RetrieveByOffsets_Call struct {
*mock.Call
}
// RetrieveByOffsets is a helper method to define mock.On call
// - ctx context.Context
// - plan *segcore.RetrievePlanWithOffsets
func (_e *MockCSegment_Expecter) RetrieveByOffsets(ctx interface{}, plan interface{}) *MockCSegment_RetrieveByOffsets_Call {
return &MockCSegment_RetrieveByOffsets_Call{Call: _e.mock.On("RetrieveByOffsets", ctx, plan)}
}
func (_c *MockCSegment_RetrieveByOffsets_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlanWithOffsets)) *MockCSegment_RetrieveByOffsets_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.RetrievePlanWithOffsets))
})
return _c
}
func (_c *MockCSegment_RetrieveByOffsets_Call) Return(_a0 *segcore.RetrieveResult, _a1 error) *MockCSegment_RetrieveByOffsets_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCSegment_RetrieveByOffsets_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcore.RetrieveResult, error)) *MockCSegment_RetrieveByOffsets_Call {
_c.Call.Return(run)
return _c
}
// RowNum provides a mock function with given fields:
func (_m *MockCSegment) RowNum() int64 {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for RowNum")
}
var r0 int64
if rf, ok := ret.Get(0).(func() int64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int64)
}
return r0
}
// MockCSegment_RowNum_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RowNum'
type MockCSegment_RowNum_Call struct {
*mock.Call
}
// RowNum is a helper method to define mock.On call
func (_e *MockCSegment_Expecter) RowNum() *MockCSegment_RowNum_Call {
return &MockCSegment_RowNum_Call{Call: _e.mock.On("RowNum")}
}
func (_c *MockCSegment_RowNum_Call) Run(run func()) *MockCSegment_RowNum_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCSegment_RowNum_Call) Return(_a0 int64) *MockCSegment_RowNum_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCSegment_RowNum_Call) RunAndReturn(run func() int64) *MockCSegment_RowNum_Call {
_c.Call.Return(run)
return _c
}
// Search provides a mock function with given fields: ctx, searchReq
func (_m *MockCSegment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) {
ret := _m.Called(ctx, searchReq)
if len(ret) == 0 {
panic("no return value specified for Search")
}
var r0 *segcore.SearchResult
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)); ok {
return rf(ctx, searchReq)
}
if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) *segcore.SearchResult); ok {
r0 = rf(ctx, searchReq)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcore.SearchResult)
}
}
if rf, ok := ret.Get(1).(func(context.Context, *segcore.SearchRequest) error); ok {
r1 = rf(ctx, searchReq)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCSegment_Search_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Search'
type MockCSegment_Search_Call struct {
*mock.Call
}
// Search is a helper method to define mock.On call
// - ctx context.Context
// - searchReq *segcore.SearchRequest
func (_e *MockCSegment_Expecter) Search(ctx interface{}, searchReq interface{}) *MockCSegment_Search_Call {
return &MockCSegment_Search_Call{Call: _e.mock.On("Search", ctx, searchReq)}
}
func (_c *MockCSegment_Search_Call) Run(run func(ctx context.Context, searchReq *segcore.SearchRequest)) *MockCSegment_Search_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*segcore.SearchRequest))
})
return _c
}
func (_c *MockCSegment_Search_Call) Return(_a0 *segcore.SearchResult, _a1 error) *MockCSegment_Search_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCSegment_Search_Call) RunAndReturn(run func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)) *MockCSegment_Search_Call {
_c.Call.Return(run)
return _c
}
// NewMockCSegment creates a new instance of MockCSegment. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockCSegment(t interface {
mock.TestingT
Cleanup(func())
}) *MockCSegment {
mock := &MockCSegment{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package segments package mock_segcore
import ( import (
"context" "context"
@ -43,6 +43,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
storage "github.com/milvus-io/milvus/internal/storage" storage "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/indexcgowrapper" "github.com/milvus-io/milvus/internal/util/indexcgowrapper"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/mq/msgstream"
@ -78,7 +79,7 @@ const (
rowIDFieldID = 0 rowIDFieldID = 0
timestampFieldID = 1 timestampFieldID = 1
metricTypeKey = common.MetricTypeKey metricTypeKey = common.MetricTypeKey
defaultDim = 128 DefaultDim = 128
defaultMetricType = metric.L2 defaultMetricType = metric.L2
dimKey = common.DimKey dimKey = common.DimKey
@ -89,7 +90,7 @@ const (
// ---------- unittest util functions ---------- // ---------- unittest util functions ----------
// gen collection schema for // gen collection schema for
type vecFieldParam struct { type vecFieldParam struct {
id int64 ID int64
dim int dim int
metricType string metricType string
vecType schemapb.DataType vecType schemapb.DataType
@ -97,125 +98,125 @@ type vecFieldParam struct {
} }
type constFieldParam struct { type constFieldParam struct {
id int64 ID int64
dataType schemapb.DataType dataType schemapb.DataType
fieldName string fieldName string
} }
var simpleFloatVecField = vecFieldParam{ var SimpleFloatVecField = vecFieldParam{
id: 100, ID: 100,
dim: defaultDim, dim: DefaultDim,
metricType: defaultMetricType, metricType: defaultMetricType,
vecType: schemapb.DataType_FloatVector, vecType: schemapb.DataType_FloatVector,
fieldName: "floatVectorField", fieldName: "floatVectorField",
} }
var simpleBinVecField = vecFieldParam{ var simpleBinVecField = vecFieldParam{
id: 101, ID: 101,
dim: defaultDim, dim: DefaultDim,
metricType: metric.JACCARD, metricType: metric.JACCARD,
vecType: schemapb.DataType_BinaryVector, vecType: schemapb.DataType_BinaryVector,
fieldName: "binVectorField", fieldName: "binVectorField",
} }
var simpleFloat16VecField = vecFieldParam{ var simpleFloat16VecField = vecFieldParam{
id: 112, ID: 112,
dim: defaultDim, dim: DefaultDim,
metricType: defaultMetricType, metricType: defaultMetricType,
vecType: schemapb.DataType_Float16Vector, vecType: schemapb.DataType_Float16Vector,
fieldName: "float16VectorField", fieldName: "float16VectorField",
} }
var simpleBFloat16VecField = vecFieldParam{ var simpleBFloat16VecField = vecFieldParam{
id: 113, ID: 113,
dim: defaultDim, dim: DefaultDim,
metricType: defaultMetricType, metricType: defaultMetricType,
vecType: schemapb.DataType_BFloat16Vector, vecType: schemapb.DataType_BFloat16Vector,
fieldName: "bfloat16VectorField", fieldName: "bfloat16VectorField",
} }
var simpleSparseFloatVectorField = vecFieldParam{ var SimpleSparseFloatVectorField = vecFieldParam{
id: 114, ID: 114,
metricType: metric.IP, metricType: metric.IP,
vecType: schemapb.DataType_SparseFloatVector, vecType: schemapb.DataType_SparseFloatVector,
fieldName: "sparseFloatVectorField", fieldName: "sparseFloatVectorField",
} }
var simpleBoolField = constFieldParam{ var simpleBoolField = constFieldParam{
id: 102, ID: 102,
dataType: schemapb.DataType_Bool, dataType: schemapb.DataType_Bool,
fieldName: "boolField", fieldName: "boolField",
} }
var simpleInt8Field = constFieldParam{ var simpleInt8Field = constFieldParam{
id: 103, ID: 103,
dataType: schemapb.DataType_Int8, dataType: schemapb.DataType_Int8,
fieldName: "int8Field", fieldName: "int8Field",
} }
var simpleInt16Field = constFieldParam{ var simpleInt16Field = constFieldParam{
id: 104, ID: 104,
dataType: schemapb.DataType_Int16, dataType: schemapb.DataType_Int16,
fieldName: "int16Field", fieldName: "int16Field",
} }
var simpleInt32Field = constFieldParam{ var simpleInt32Field = constFieldParam{
id: 105, ID: 105,
dataType: schemapb.DataType_Int32, dataType: schemapb.DataType_Int32,
fieldName: "int32Field", fieldName: "int32Field",
} }
var simpleInt64Field = constFieldParam{ var simpleInt64Field = constFieldParam{
id: 106, ID: 106,
dataType: schemapb.DataType_Int64, dataType: schemapb.DataType_Int64,
fieldName: "int64Field", fieldName: "int64Field",
} }
var simpleFloatField = constFieldParam{ var simpleFloatField = constFieldParam{
id: 107, ID: 107,
dataType: schemapb.DataType_Float, dataType: schemapb.DataType_Float,
fieldName: "floatField", fieldName: "floatField",
} }
var simpleDoubleField = constFieldParam{ var simpleDoubleField = constFieldParam{
id: 108, ID: 108,
dataType: schemapb.DataType_Double, dataType: schemapb.DataType_Double,
fieldName: "doubleField", fieldName: "doubleField",
} }
var simpleJSONField = constFieldParam{ var simpleJSONField = constFieldParam{
id: 109, ID: 109,
dataType: schemapb.DataType_JSON, dataType: schemapb.DataType_JSON,
fieldName: "jsonField", fieldName: "jsonField",
} }
var simpleArrayField = constFieldParam{ var simpleArrayField = constFieldParam{
id: 110, ID: 110,
dataType: schemapb.DataType_Array, dataType: schemapb.DataType_Array,
fieldName: "arrayField", fieldName: "arrayField",
} }
var simpleVarCharField = constFieldParam{ var simpleVarCharField = constFieldParam{
id: 111, ID: 111,
dataType: schemapb.DataType_VarChar, dataType: schemapb.DataType_VarChar,
fieldName: "varCharField", fieldName: "varCharField",
} }
var rowIDField = constFieldParam{ var RowIDField = constFieldParam{
id: rowIDFieldID, ID: rowIDFieldID,
dataType: schemapb.DataType_Int64, dataType: schemapb.DataType_Int64,
fieldName: "RowID", fieldName: "RowID",
} }
var timestampField = constFieldParam{ var timestampField = constFieldParam{
id: timestampFieldID, ID: timestampFieldID,
dataType: schemapb.DataType_Int64, dataType: schemapb.DataType_Int64,
fieldName: "Timestamp", fieldName: "Timestamp",
} }
func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema { func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema {
field := &schemapb.FieldSchema{ field := &schemapb.FieldSchema{
FieldID: param.id, FieldID: param.ID,
Name: param.fieldName, Name: param.fieldName,
IsPrimaryKey: false, IsPrimaryKey: false,
DataType: param.dataType, DataType: param.dataType,
@ -231,7 +232,7 @@ func genConstantFieldSchema(param constFieldParam) *schemapb.FieldSchema {
func genPKFieldSchema(param constFieldParam) *schemapb.FieldSchema { func genPKFieldSchema(param constFieldParam) *schemapb.FieldSchema {
field := &schemapb.FieldSchema{ field := &schemapb.FieldSchema{
FieldID: param.id, FieldID: param.ID,
Name: param.fieldName, Name: param.fieldName,
IsPrimaryKey: true, IsPrimaryKey: true,
DataType: param.dataType, DataType: param.dataType,
@ -247,7 +248,7 @@ func genPKFieldSchema(param constFieldParam) *schemapb.FieldSchema {
func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema { func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema {
fieldVec := &schemapb.FieldSchema{ fieldVec := &schemapb.FieldSchema{
FieldID: param.id, FieldID: param.ID,
Name: param.fieldName, Name: param.fieldName,
IsPrimaryKey: false, IsPrimaryKey: false,
DataType: param.vecType, DataType: param.vecType,
@ -270,11 +271,11 @@ func genVectorFieldSchema(param vecFieldParam) *schemapb.FieldSchema {
} }
func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSchema { func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSchema {
fieldRowID := genConstantFieldSchema(rowIDField) fieldRowID := genConstantFieldSchema(RowIDField)
fieldTimestamp := genConstantFieldSchema(timestampField) fieldTimestamp := genConstantFieldSchema(timestampField)
pkFieldSchema := genPKFieldSchema(simpleInt64Field) pkFieldSchema := genPKFieldSchema(simpleInt64Field)
textFieldSchema := genConstantFieldSchema(simpleVarCharField) textFieldSchema := genConstantFieldSchema(simpleVarCharField)
sparseFieldSchema := genVectorFieldSchema(simpleSparseFloatVectorField) sparseFieldSchema := genVectorFieldSchema(SimpleSparseFloatVectorField)
sparseFieldSchema.IsFunctionOutput = true sparseFieldSchema.IsFunctionOutput = true
schema := &schemapb.CollectionSchema{ schema := &schemapb.CollectionSchema{
@ -301,7 +302,7 @@ func GenTestBM25CollectionSchema(collectionName string) *schemapb.CollectionSche
// some tests do not yet support sparse float vector, see comments of // some tests do not yet support sparse float vector, see comments of
// GenSparseFloatVecDataset in indexcgowrapper/dataset.go // GenSparseFloatVecDataset in indexcgowrapper/dataset.go
func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema { func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, withSparse bool) *schemapb.CollectionSchema {
fieldRowID := genConstantFieldSchema(rowIDField) fieldRowID := genConstantFieldSchema(RowIDField)
fieldTimestamp := genConstantFieldSchema(timestampField) fieldTimestamp := genConstantFieldSchema(timestampField)
fieldBool := genConstantFieldSchema(simpleBoolField) fieldBool := genConstantFieldSchema(simpleBoolField)
fieldInt8 := genConstantFieldSchema(simpleInt8Field) fieldInt8 := genConstantFieldSchema(simpleInt8Field)
@ -312,7 +313,7 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
// fieldArray := genConstantFieldSchema(simpleArrayField) // fieldArray := genConstantFieldSchema(simpleArrayField)
fieldJSON := genConstantFieldSchema(simpleJSONField) fieldJSON := genConstantFieldSchema(simpleJSONField)
fieldArray := genConstantFieldSchema(simpleArrayField) fieldArray := genConstantFieldSchema(simpleArrayField)
floatVecFieldSchema := genVectorFieldSchema(simpleFloatVecField) floatVecFieldSchema := genVectorFieldSchema(SimpleFloatVecField)
binVecFieldSchema := genVectorFieldSchema(simpleBinVecField) binVecFieldSchema := genVectorFieldSchema(simpleBinVecField)
float16VecFieldSchema := genVectorFieldSchema(simpleFloat16VecField) float16VecFieldSchema := genVectorFieldSchema(simpleFloat16VecField)
bfloat16VecFieldSchema := genVectorFieldSchema(simpleBFloat16VecField) bfloat16VecFieldSchema := genVectorFieldSchema(simpleBFloat16VecField)
@ -346,7 +347,7 @@ func GenTestCollectionSchema(collectionName string, pkType schemapb.DataType, wi
} }
if withSparse { if withSparse {
schema.Fields = append(schema.Fields, genVectorFieldSchema(simpleSparseFloatVectorField)) schema.Fields = append(schema.Fields, genVectorFieldSchema(SimpleSparseFloatVectorField))
} }
for i, field := range schema.GetFields() { for i, field := range schema.GetFields() {
@ -477,7 +478,7 @@ func SaveBinLog(ctx context.Context,
return nil, nil, err return nil, nil, err
} }
k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) k := metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID)
key := path.Join(chunkManager.RootPath(), "insert-log", k) key := path.Join(chunkManager.RootPath(), "insert-log", k)
kvs[key] = blob.Value kvs[key] = blob.Value
fieldBinlog = append(fieldBinlog, &datapb.FieldBinlog{ fieldBinlog = append(fieldBinlog, &datapb.FieldBinlog{
@ -499,7 +500,7 @@ func SaveBinLog(ctx context.Context,
return nil, nil, err return nil, nil, err
} }
k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) k := metautil.JoinIDPath(collectionID, partitionID, segmentID, fieldID)
key := path.Join(chunkManager.RootPath(), "stats-log", k) key := path.Join(chunkManager.RootPath(), "stats-log", k)
kvs[key] = blob.Value kvs[key] = blob.Value
statsBinlog = append(statsBinlog, &datapb.FieldBinlog{ statsBinlog = append(statsBinlog, &datapb.FieldBinlog{
@ -597,7 +598,7 @@ func genInsertData(msgLength int, schema *schemapb.CollectionSchema) (*storage.I
Data: testutils.GenerateJSONArray(msgLength), Data: testutils.GenerateJSONArray(msgLength),
} }
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
dim := simpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim dim := SimpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim
insertData.Data[f.FieldID] = &storage.FloatVectorFieldData{ insertData.Data[f.FieldID] = &storage.FloatVectorFieldData{
Data: testutils.GenerateFloatVectors(msgLength, dim), Data: testutils.GenerateFloatVectors(msgLength, dim),
Dim: dim, Dim: dim,
@ -689,7 +690,7 @@ func SaveDeltaLog(collectionID int64,
pkFieldID := int64(106) pkFieldID := int64(106)
fieldBinlog := make([]*datapb.FieldBinlog, 0) fieldBinlog := make([]*datapb.FieldBinlog, 0)
log.Debug("[query node unittest] save delta log", zap.Int64("fieldID", pkFieldID)) log.Debug("[query node unittest] save delta log", zap.Int64("fieldID", pkFieldID))
key := JoinIDPath(collectionID, partitionID, segmentID, pkFieldID) key := metautil.JoinIDPath(collectionID, partitionID, segmentID, pkFieldID)
// keyPath := path.Join(defaultLocalStorage, "delta-log", key) // keyPath := path.Join(defaultLocalStorage, "delta-log", key)
keyPath := path.Join(cm.RootPath(), "delta-log", key) keyPath := path.Join(cm.RootPath(), "delta-log", key)
kvs[keyPath] = blob.Value kvs[keyPath] = blob.Value
@ -750,13 +751,13 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
var dataset *indexcgowrapper.Dataset var dataset *indexcgowrapper.Dataset
switch fieldSchema.DataType { switch fieldSchema.DataType {
case schemapb.DataType_BinaryVector: case schemapb.DataType_BinaryVector:
dataset = indexcgowrapper.GenBinaryVecDataset(testutils.GenerateBinaryVectors(msgLength, defaultDim)) dataset = indexcgowrapper.GenBinaryVecDataset(testutils.GenerateBinaryVectors(msgLength, DefaultDim))
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
dataset = indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, defaultDim)) dataset = indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, DefaultDim))
case schemapb.DataType_Float16Vector: case schemapb.DataType_Float16Vector:
dataset = indexcgowrapper.GenFloat16VecDataset(testutils.GenerateFloat16Vectors(msgLength, defaultDim)) dataset = indexcgowrapper.GenFloat16VecDataset(testutils.GenerateFloat16Vectors(msgLength, DefaultDim))
case schemapb.DataType_BFloat16Vector: case schemapb.DataType_BFloat16Vector:
dataset = indexcgowrapper.GenBFloat16VecDataset(testutils.GenerateBFloat16Vectors(msgLength, defaultDim)) dataset = indexcgowrapper.GenBFloat16VecDataset(testutils.GenerateBFloat16Vectors(msgLength, DefaultDim))
case schemapb.DataType_SparseFloatVector: case schemapb.DataType_SparseFloatVector:
contents, dim := testutils.GenerateSparseFloatVectorsData(msgLength) contents, dim := testutils.GenerateSparseFloatVectorsData(msgLength)
dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{ dataset = indexcgowrapper.GenSparseFloatVecDataset(&storage.SparseFloatVectorFieldData{
@ -806,14 +807,14 @@ func GenAndSaveIndexV2(collectionID, partitionID, segmentID, buildID int64,
return nil, err return nil, err
} }
} }
_, cCurrentIndexVersion := getIndexEngineVersion() indexVersion := segcore.GetIndexEngineInfo()
return &querypb.FieldIndexInfo{ return &querypb.FieldIndexInfo{
FieldID: fieldSchema.GetFieldID(), FieldID: fieldSchema.GetFieldID(),
IndexName: indexInfo.GetIndexName(), IndexName: indexInfo.GetIndexName(),
IndexParams: indexInfo.GetIndexParams(), IndexParams: indexInfo.GetIndexParams(),
IndexFilePaths: indexPaths, IndexFilePaths: indexPaths,
CurrentIndexVersion: cCurrentIndexVersion, CurrentIndexVersion: indexVersion.CurrentIndexVersion,
}, nil }, nil
} }
@ -826,7 +827,7 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen
} }
defer index.Delete() defer index.Delete()
err = index.Build(indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, defaultDim))) err = index.Build(indexcgowrapper.GenFloatVecDataset(testutils.GenerateFloatVectors(msgLength, DefaultDim)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -845,7 +846,7 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen
collectionID, collectionID,
partitionID, partitionID,
segmentID, segmentID,
simpleFloatVecField.id, SimpleFloatVecField.ID,
indexParams, indexParams,
"querynode-test", "querynode-test",
0, 0,
@ -866,20 +867,20 @@ func GenAndSaveIndex(collectionID, partitionID, segmentID, fieldID int64, msgLen
return nil, err return nil, err
} }
} }
_, cCurrentIndexVersion := getIndexEngineVersion() indexEngineInfo := segcore.GetIndexEngineInfo()
return &querypb.FieldIndexInfo{ return &querypb.FieldIndexInfo{
FieldID: fieldID, FieldID: fieldID,
IndexName: "querynode-test", IndexName: "querynode-test",
IndexParams: funcutil.Map2KeyValuePair(indexParams), IndexParams: funcutil.Map2KeyValuePair(indexParams),
IndexFilePaths: indexPaths, IndexFilePaths: indexPaths,
CurrentIndexVersion: cCurrentIndexVersion, CurrentIndexVersion: indexEngineInfo.CurrentIndexVersion,
}, nil }, nil
} }
func genIndexParams(indexType, metricType string) (map[string]string, map[string]string) { func genIndexParams(indexType, metricType string) (map[string]string, map[string]string) {
typeParams := make(map[string]string) typeParams := make(map[string]string)
typeParams[common.DimKey] = strconv.Itoa(defaultDim) typeParams[common.DimKey] = strconv.Itoa(DefaultDim)
indexParams := make(map[string]string) indexParams := make(map[string]string)
indexParams[common.IndexTypeKey] = indexType indexParams[common.IndexTypeKey] = indexType
@ -927,7 +928,7 @@ func genStorageConfig() *indexpb.StorageConfig {
} }
} }
func genSearchRequest(nq int64, indexType string, collection *Collection) (*internalpb.SearchRequest, error) { func genSearchRequest(nq int64, indexType string, collection *segcore.CCollection) (*internalpb.SearchRequest, error) {
placeHolder, err := genPlaceHolderGroup(nq) placeHolder, err := genPlaceHolderGroup(nq)
if err != nil { if err != nil {
return nil, err return nil, err
@ -946,7 +947,6 @@ func genSearchRequest(nq int64, indexType string, collection *Collection) (*inte
return &internalpb.SearchRequest{ return &internalpb.SearchRequest{
Base: genCommonMsgBase(commonpb.MsgType_Search, 0), Base: genCommonMsgBase(commonpb.MsgType_Search, 0),
CollectionID: collection.ID(), CollectionID: collection.ID(),
PartitionIDs: collection.GetPartitions(),
PlaceholderGroup: placeHolder, PlaceholderGroup: placeHolder,
SerializedExprPlan: serializedPlan, SerializedExprPlan: serializedPlan,
DslType: commonpb.DslType_BoolExprV1, DslType: commonpb.DslType_BoolExprV1,
@ -969,8 +969,8 @@ func genPlaceHolderGroup(nq int64) ([]byte, error) {
Values: make([][]byte, 0), Values: make([][]byte, 0),
} }
for i := int64(0); i < nq; i++ { for i := int64(0); i < nq; i++ {
vec := make([]float32, defaultDim) vec := make([]float32, DefaultDim)
for j := 0; j < defaultDim; j++ { for j := 0; j < DefaultDim; j++ {
vec[j] = rand.Float32() vec[j] = rand.Float32()
} }
var rawData []byte var rawData []byte
@ -1070,22 +1070,22 @@ func genHNSWDSL(schema *schemapb.CollectionSchema, ef int, topK int64, roundDeci
>`, nil >`, nil
} }
func checkSearchResult(ctx context.Context, nq int64, plan *SearchPlan, searchResult *SearchResult) error { func CheckSearchResult(ctx context.Context, nq int64, plan *segcore.SearchPlan, searchResult *segcore.SearchResult) error {
searchResults := make([]*SearchResult, 0) searchResults := make([]*segcore.SearchResult, 0)
searchResults = append(searchResults, searchResult) searchResults = append(searchResults, searchResult)
topK := plan.getTopK() topK := plan.GetTopK()
sliceNQs := []int64{nq / 5, nq / 5, nq / 5, nq / 5, nq / 5} sliceNQs := []int64{nq / 5, nq / 5, nq / 5, nq / 5, nq / 5}
sliceTopKs := []int64{topK, topK / 2, topK, topK, topK / 2} sliceTopKs := []int64{topK, topK / 2, topK, topK, topK / 2}
sInfo := ParseSliceInfo(sliceNQs, sliceTopKs, nq) sInfo := segcore.ParseSliceInfo(sliceNQs, sliceTopKs, nq)
res, err := ReduceSearchResultsAndFillData(ctx, plan, searchResults, 1, sInfo.SliceNQs, sInfo.SliceTopKs) res, err := segcore.ReduceSearchResultsAndFillData(ctx, plan, searchResults, 1, sInfo.SliceNQs, sInfo.SliceTopKs)
if err != nil { if err != nil {
return err return err
} }
for i := 0; i < len(sInfo.SliceNQs); i++ { for i := 0; i < len(sInfo.SliceNQs); i++ {
blob, err := GetSearchResultDataBlob(ctx, res, i) blob, err := segcore.GetSearchResultDataBlob(ctx, res, i)
if err != nil { if err != nil {
return err return err
} }
@ -1114,12 +1114,14 @@ func checkSearchResult(ctx context.Context, nq int64, plan *SearchPlan, searchRe
} }
} }
DeleteSearchResults(searchResults) for _, searchResult := range searchResults {
DeleteSearchResultDataBlobs(res) searchResult.Release()
}
segcore.DeleteSearchResultDataBlobs(res)
return nil return nil
} }
func genSearchPlanAndRequests(collection *Collection, segments []int64, indexType string, nq int64) (*SearchRequest, error) { func GenSearchPlanAndRequests(collection *segcore.CCollection, segments []int64, indexType string, nq int64) (*segcore.SearchRequest, error) {
iReq, _ := genSearchRequest(nq, indexType, collection) iReq, _ := genSearchRequest(nq, indexType, collection)
queryReq := &querypb.SearchRequest{ queryReq := &querypb.SearchRequest{
Req: iReq, Req: iReq,
@ -1127,10 +1129,10 @@ func genSearchPlanAndRequests(collection *Collection, segments []int64, indexTyp
SegmentIDs: segments, SegmentIDs: segments,
Scope: querypb.DataScope_Historical, Scope: querypb.DataScope_Historical,
} }
return NewSearchRequest(context.Background(), collection, queryReq, queryReq.Req.GetPlaceholderGroup()) return segcore.NewSearchRequest(collection, queryReq, queryReq.Req.GetPlaceholderGroup())
} }
func genInsertMsg(collection *Collection, partitionID, segment int64, numRows int) (*msgstream.InsertMsg, error) { func GenInsertMsg(collection *segcore.CCollection, partitionID, segment int64, numRows int) (*msgstream.InsertMsg, error) {
fieldsData := make([]*schemapb.FieldData, 0) fieldsData := make([]*schemapb.FieldData, 0)
for _, f := range collection.Schema().Fields { for _, f := range collection.Schema().Fields {
@ -1156,7 +1158,7 @@ func genInsertMsg(collection *Collection, partitionID, segment int64, numRows in
case schemapb.DataType_JSON: case schemapb.DataType_JSON:
fieldsData = append(fieldsData, testutils.GenerateScalarFieldDataWithID(f.DataType, simpleJSONField.fieldName, f.GetFieldID(), numRows)) fieldsData = append(fieldsData, testutils.GenerateScalarFieldDataWithID(f.DataType, simpleJSONField.fieldName, f.GetFieldID(), numRows))
case schemapb.DataType_FloatVector: case schemapb.DataType_FloatVector:
dim := simpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim dim := SimpleFloatVecField.dim // if no dim specified, use simpleFloatVecField's dim
fieldsData = append(fieldsData, testutils.GenerateVectorFieldDataWithID(f.DataType, f.Name, f.FieldID, numRows, dim)) fieldsData = append(fieldsData, testutils.GenerateVectorFieldDataWithID(f.DataType, f.Name, f.FieldID, numRows, dim))
case schemapb.DataType_BinaryVector: case schemapb.DataType_BinaryVector:
dim := simpleBinVecField.dim // if no dim specified, use simpleFloatVecField's dim dim := simpleBinVecField.dim // if no dim specified, use simpleFloatVecField's dim
@ -1227,14 +1229,14 @@ func genSimpleRowIDField(numRows int) []int64 {
return ids return ids
} }
func genSimpleRetrievePlan(collection *Collection) (*RetrievePlan, error) { func GenSimpleRetrievePlan(collection *segcore.CCollection) (*segcore.RetrievePlan, error) {
timestamp := storage.Timestamp(1000) timestamp := storage.Timestamp(1000)
planBytes, err := genSimpleRetrievePlanExpr(collection.schema.Load()) planBytes, err := genSimpleRetrievePlanExpr(collection.Schema())
if err != nil { if err != nil {
return nil, err return nil, err
} }
plan, err2 := NewRetrievePlan(context.Background(), collection, planBytes, timestamp, 100) plan, err2 := segcore.NewRetrievePlan(collection, planBytes, timestamp, 100)
return plan, err2 return plan, err2
} }
@ -1279,14 +1281,14 @@ func genSimpleRetrievePlanExpr(schema *schemapb.CollectionSchema) ([]byte, error
return planExpr, err return planExpr, err
} }
func genFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, fieldValue interface{}, dim int64) *schemapb.FieldData { func GenFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, fieldValue interface{}, dim int64) *schemapb.FieldData {
if fieldType < 100 { if fieldType < 100 {
return testutils.GenerateScalarFieldDataWithValue(fieldType, fieldName, fieldID, fieldValue) return testutils.GenerateScalarFieldDataWithValue(fieldType, fieldName, fieldID, fieldValue)
} }
return testutils.GenerateVectorFieldDataWithValue(fieldType, fieldName, fieldID, fieldValue, int(dim)) return testutils.GenerateVectorFieldDataWithValue(fieldType, fieldName, fieldID, fieldValue, int(dim))
} }
func genSearchResultData(nq int64, topk int64, ids []int64, scores []float32, topks []int64) *schemapb.SearchResultData { func GenSearchResultData(nq int64, topk int64, ids []int64, scores []float32, topks []int64) *schemapb.SearchResultData {
return &schemapb.SearchResultData{ return &schemapb.SearchResultData{
NumQueries: 1, NumQueries: 1,
TopK: topk, TopK: topk,

View File

@ -1,6 +1,6 @@
// Code generated by mockery v2.46.0. DO NOT EDIT. // Code generated by mockery v2.46.0. DO NOT EDIT.
package optimizers package mock_optimizers
import mock "github.com/stretchr/testify/mock" import mock "github.com/stretchr/testify/mock"

View File

@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
@ -1519,7 +1520,7 @@ func (s *DelegatorDataSuite) TestLevel0Deletions() {
err = allPartitionDeleteData.Append(storage.NewInt64PrimaryKey(2), 101) err = allPartitionDeleteData.Append(storage.NewInt64PrimaryKey(2), 101)
s.Require().NoError(err) s.Require().NoError(err)
schema := segments.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true)
collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
}) })

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
@ -93,8 +94,8 @@ func (suite *LocalWorkerTestSuite) BeforeTest(suiteName, testName string) {
err = suite.node.Start() err = suite.node.Start()
suite.NoError(err) suite.NoError(err)
suite.schema = segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) suite.schema = mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
suite.indexMeta = segments.GenTestIndexMeta(suite.collectionID, suite.schema) suite.indexMeta = mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema)
collection := segments.NewCollection(suite.collectionID, suite.schema, suite.indexMeta, &querypb.LoadMetaInfo{ collection := segments.NewCollection(suite.collectionID, suite.schema, suite.indexMeta, &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
}) })
@ -114,7 +115,7 @@ func (suite *LocalWorkerTestSuite) AfterTest(suiteName, testName string) {
func (suite *LocalWorkerTestSuite) TestLoadSegment() { func (suite *LocalWorkerTestSuite) TestLoadSegment() {
// load empty // load empty
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
TargetID: suite.node.session.GetServerID(), TargetID: suite.node.session.GetServerID(),

View File

@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/segments"
@ -58,10 +59,10 @@ func (suite *InsertNodeSuite) SetupSuite() {
func (suite *InsertNodeSuite) TestBasic() { func (suite *InsertNodeSuite) TestBasic() {
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
in := suite.buildInsertNodeMsg(schema) in := suite.buildInsertNodeMsg(schema)
collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
}) })
collection.AddPartition(suite.partitionID) collection.AddPartition(suite.partitionID)
@ -94,10 +95,10 @@ func (suite *InsertNodeSuite) TestBasic() {
} }
func (suite *InsertNodeSuite) TestDataTypeNotSupported() { func (suite *InsertNodeSuite) TestDataTypeNotSupported() {
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
in := suite.buildInsertNodeMsg(schema) in := suite.buildInsertNodeMsg(schema)
collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
}) })
collection.AddPartition(suite.partitionID) collection.AddPartition(suite.partitionID)

View File

@ -26,6 +26,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator" "github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/segments"
@ -108,8 +109,8 @@ func (suite *PipelineTestSuite) SetupTest() {
func (suite *PipelineTestSuite) TestBasic() { func (suite *PipelineTestSuite) TestBasic() {
// init mock // init mock
// mock collection manager // mock collection manager
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true)
collection := segments.NewCollection(suite.collectionID, schema, segments.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
}) })
suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection)

View File

@ -28,13 +28,10 @@ import "C"
import ( import (
"context" "context"
"math"
"unsafe" "unsafe"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/internal/util/cgoconverter"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
) )
@ -55,38 +52,3 @@ func HandleCStatus(ctx context.Context, status *C.CStatus, extraInfo string, fie
log.Warn("CStatus returns err", zap.Error(err), zap.String("extra", extraInfo)) log.Warn("CStatus returns err", zap.Error(err), zap.String("extra", extraInfo))
return err return err
} }
// UnmarshalCProto unmarshal the proto from C memory
func UnmarshalCProto(cRes *C.CProto, msg proto.Message) error {
blob := (*(*[math.MaxInt32]byte)(cRes.proto_blob))[:int(cRes.proto_size):int(cRes.proto_size)]
return proto.Unmarshal(blob, msg)
}
// CopyCProtoBlob returns the copy of C memory
func CopyCProtoBlob(cProto *C.CProto) []byte {
blob := C.GoBytes(cProto.proto_blob, C.int32_t(cProto.proto_size))
C.free(cProto.proto_blob)
return blob
}
// GetCProtoBlob returns the raw C memory, invoker should release it itself
func GetCProtoBlob(cProto *C.CProto) []byte {
lease, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size))
cgoconverter.Extract(lease)
return blob
}
func GetLocalUsedSize(ctx context.Context, path string) (int64, error) {
var availableSize int64
cSize := (*C.int64_t)(&availableSize)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
status := C.GetLocalUsedSize(cPath, cSize)
err := HandleCStatus(ctx, &status, "get local used size failed")
if err != nil {
return 0, err
}
return availableSize, nil
}

View File

@ -16,27 +16,18 @@
package segments package segments
/*
#cgo pkg-config: milvus_core
#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
*/
import "C"
import ( import (
"sync" "sync"
"unsafe"
"github.com/samber/lo" "github.com/samber/lo"
"go.uber.org/atomic" "go.uber.org/atomic"
"go.uber.org/zap" "go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
@ -145,7 +136,7 @@ func (m *collectionManager) Unref(collectionID int64, count uint32) bool {
// In a query node, `Collection` is a replica info of a collection in these query node. // In a query node, `Collection` is a replica info of a collection in these query node.
type Collection struct { type Collection struct {
mu sync.RWMutex // protects colllectionPtr mu sync.RWMutex // protects colllectionPtr
collectionPtr C.CCollection ccollection *segcore.CCollection
id int64 id int64
partitions *typeutil.ConcurrentSet[int64] partitions *typeutil.ConcurrentSet[int64]
loadType querypb.LoadType loadType querypb.LoadType
@ -178,6 +169,11 @@ func (c *Collection) ID() int64 {
return c.id return c.id
} }
// GetCCollection returns the CCollection of collection
func (c *Collection) GetCCollection() *segcore.CCollection {
return c.ccollection
}
// Schema returns the schema of collection // Schema returns the schema of collection
func (c *Collection) Schema() *schemapb.CollectionSchema { func (c *Collection) Schema() *schemapb.CollectionSchema {
return c.schema.Load() return c.schema.Load()
@ -254,23 +250,12 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
loadFieldIDs = typeutil.NewSet(lo.Map(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) int64 { return field.GetFieldID() })...) loadFieldIDs = typeutil.NewSet(lo.Map(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) int64 { return field.GetFieldID() })...)
} }
schemaBlob, err := proto.Marshal(loadSchema)
if err != nil {
log.Warn("marshal schema failed", zap.Error(err))
return nil
}
collection := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob)))
isGpuIndex := false isGpuIndex := false
req := &segcore.CreateCCollectionRequest{
Schema: loadSchema,
}
if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 { if indexMeta != nil && len(indexMeta.GetIndexMetas()) > 0 && indexMeta.GetMaxIndexRowCount() > 0 {
indexMetaBlob, err := proto.Marshal(indexMeta) req.IndexMeta = indexMeta
if err != nil {
log.Warn("marshal index meta failed", zap.Error(err))
return nil
}
C.SetIndexMeta(collection, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob)))
for _, indexMeta := range indexMeta.GetIndexMetas() { for _, indexMeta := range indexMeta.GetIndexMetas() {
isGpuIndex = lo.ContainsBy(indexMeta.GetIndexParams(), func(param *commonpb.KeyValuePair) bool { isGpuIndex = lo.ContainsBy(indexMeta.GetIndexParams(), func(param *commonpb.KeyValuePair) bool {
return param.Key == common.IndexTypeKey && vecindexmgr.GetVecIndexMgrInstance().IsGPUVecIndex(param.Value) return param.Key == common.IndexTypeKey && vecindexmgr.GetVecIndexMgrInstance().IsGPUVecIndex(param.Value)
@ -281,8 +266,13 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM
} }
} }
ccollection, err := segcore.CreateCCollection(req)
if err != nil {
log.Warn("create collection failed", zap.Error(err))
return nil
}
coll := &Collection{ coll := &Collection{
collectionPtr: collection, ccollection: ccollection,
id: collectionID, id: collectionID,
partitions: typeutil.NewConcurrentSet[int64](), partitions: typeutil.NewConcurrentSet[int64](),
loadType: loadMetaInfo.GetLoadType(), loadType: loadMetaInfo.GetLoadType(),
@ -330,10 +320,9 @@ func DeleteCollection(collection *Collection) {
collection.mu.Lock() collection.mu.Lock()
defer collection.mu.Unlock() defer collection.mu.Unlock()
cPtr := collection.collectionPtr if collection.ccollection == nil {
if cPtr != nil { return
C.DeleteCollection(cPtr)
} }
collection.ccollection.Release()
collection.collectionPtr = nil collection.ccollection = nil
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/segcore"
) )
type cntReducer struct{} type cntReducer struct{}
@ -33,7 +34,7 @@ func (r *cntReducer) Reduce(ctx context.Context, results []*internalpb.RetrieveR
type cntReducerSegCore struct{} type cntReducerSegCore struct{}
func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, _ []Segment, _ *RetrievePlan) (*segcorepb.RetrieveResults, error) { func (r *cntReducerSegCore) Reduce(ctx context.Context, results []*segcorepb.RetrieveResults, _ []Segment, _ *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) {
cnt := int64(0) cnt := int64(0)
allRetrieveCount := int64(0) allRetrieveCount := int64(0)
for _, res := range results { for _, res := range results {

View File

@ -10,6 +10,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/internal/util/initcore"
@ -50,10 +51,10 @@ func (s *ManagerSuite) SetupTest() {
s.segments = nil s.segments = nil
for i, id := range s.segmentIDs { for i, id := range s.segmentIDs {
schema := GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64, true)
segment, err := NewSegment( segment, err := NewSegment(
context.Background(), context.Background(),
NewCollection(s.collectionIDs[i], schema, GenTestIndexMeta(s.collectionIDs[i], schema), &querypb.LoadMetaInfo{ NewCollection(s.collectionIDs[i], schema, mock_segcore.GenTestIndexMeta(s.collectionIDs[i], schema), &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
}), }),
s.types[i], s.types[i],

View File

@ -17,6 +17,8 @@ import (
querypb "github.com/milvus-io/milvus/internal/proto/querypb" querypb "github.com/milvus-io/milvus/internal/proto/querypb"
segcore "github.com/milvus-io/milvus/internal/util/segcore"
segcorepb "github.com/milvus-io/milvus/internal/proto/segcorepb" segcorepb "github.com/milvus-io/milvus/internal/proto/segcorepb"
storage "github.com/milvus-io/milvus/internal/storage" storage "github.com/milvus-io/milvus/internal/storage"
@ -1358,7 +1360,7 @@ func (_c *MockSegment_ResourceUsageEstimate_Call) RunAndReturn(run func() Resour
} }
// Retrieve provides a mock function with given fields: ctx, plan // Retrieve provides a mock function with given fields: ctx, plan
func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { func (_m *MockSegment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) {
ret := _m.Called(ctx, plan) ret := _m.Called(ctx, plan)
if len(ret) == 0 { if len(ret) == 0 {
@ -1367,10 +1369,10 @@ func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
var r0 *segcorepb.RetrieveResults var r0 *segcorepb.RetrieveResults
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error)); ok {
return rf(ctx, plan) return rf(ctx, plan)
} }
if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan) *segcorepb.RetrieveResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlan) *segcorepb.RetrieveResults); ok {
r0 = rf(ctx, plan) r0 = rf(ctx, plan)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
@ -1378,7 +1380,7 @@ func (_m *MockSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, *RetrievePlan) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlan) error); ok {
r1 = rf(ctx, plan) r1 = rf(ctx, plan)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
@ -1394,14 +1396,14 @@ type MockSegment_Retrieve_Call struct {
// Retrieve is a helper method to define mock.On call // Retrieve is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - plan *RetrievePlan // - plan *segcore.RetrievePlan
func (_e *MockSegment_Expecter) Retrieve(ctx interface{}, plan interface{}) *MockSegment_Retrieve_Call { func (_e *MockSegment_Expecter) Retrieve(ctx interface{}, plan interface{}) *MockSegment_Retrieve_Call {
return &MockSegment_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, plan)} return &MockSegment_Retrieve_Call{Call: _e.mock.On("Retrieve", ctx, plan)}
} }
func (_c *MockSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *RetrievePlan)) *MockSegment_Retrieve_Call { func (_c *MockSegment_Retrieve_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlan)) *MockSegment_Retrieve_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*RetrievePlan)) run(args[0].(context.Context), args[1].(*segcore.RetrievePlan))
}) })
return _c return _c
} }
@ -1411,14 +1413,14 @@ func (_c *MockSegment_Retrieve_Call) Return(_a0 *segcorepb.RetrieveResults, _a1
return _c return _c
} }
func (_c *MockSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *RetrievePlan) (*segcorepb.RetrieveResults, error)) *MockSegment_Retrieve_Call { func (_c *MockSegment_Retrieve_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error)) *MockSegment_Retrieve_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
// RetrieveByOffsets provides a mock function with given fields: ctx, plan, offsets // RetrieveByOffsets provides a mock function with given fields: ctx, plan
func (_m *MockSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { func (_m *MockSegment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error) {
ret := _m.Called(ctx, plan, offsets) ret := _m.Called(ctx, plan)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for RetrieveByOffsets") panic("no return value specified for RetrieveByOffsets")
@ -1426,19 +1428,19 @@ func (_m *MockSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan
var r0 *segcorepb.RetrieveResults var r0 *segcorepb.RetrieveResults
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan, []int64) (*segcorepb.RetrieveResults, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error)); ok {
return rf(ctx, plan, offsets) return rf(ctx, plan)
} }
if rf, ok := ret.Get(0).(func(context.Context, *RetrievePlan, []int64) *segcorepb.RetrieveResults); ok { if rf, ok := ret.Get(0).(func(context.Context, *segcore.RetrievePlanWithOffsets) *segcorepb.RetrieveResults); ok {
r0 = rf(ctx, plan, offsets) r0 = rf(ctx, plan)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*segcorepb.RetrieveResults) r0 = ret.Get(0).(*segcorepb.RetrieveResults)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, *RetrievePlan, []int64) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *segcore.RetrievePlanWithOffsets) error); ok {
r1 = rf(ctx, plan, offsets) r1 = rf(ctx, plan)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
} }
@ -1453,15 +1455,14 @@ type MockSegment_RetrieveByOffsets_Call struct {
// RetrieveByOffsets is a helper method to define mock.On call // RetrieveByOffsets is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - plan *RetrievePlan // - plan *segcore.RetrievePlanWithOffsets
// - offsets []int64 func (_e *MockSegment_Expecter) RetrieveByOffsets(ctx interface{}, plan interface{}) *MockSegment_RetrieveByOffsets_Call {
func (_e *MockSegment_Expecter) RetrieveByOffsets(ctx interface{}, plan interface{}, offsets interface{}) *MockSegment_RetrieveByOffsets_Call { return &MockSegment_RetrieveByOffsets_Call{Call: _e.mock.On("RetrieveByOffsets", ctx, plan)}
return &MockSegment_RetrieveByOffsets_Call{Call: _e.mock.On("RetrieveByOffsets", ctx, plan, offsets)}
} }
func (_c *MockSegment_RetrieveByOffsets_Call) Run(run func(ctx context.Context, plan *RetrievePlan, offsets []int64)) *MockSegment_RetrieveByOffsets_Call { func (_c *MockSegment_RetrieveByOffsets_Call) Run(run func(ctx context.Context, plan *segcore.RetrievePlanWithOffsets)) *MockSegment_RetrieveByOffsets_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*RetrievePlan), args[2].([]int64)) run(args[0].(context.Context), args[1].(*segcore.RetrievePlanWithOffsets))
}) })
return _c return _c
} }
@ -1471,7 +1472,7 @@ func (_c *MockSegment_RetrieveByOffsets_Call) Return(_a0 *segcorepb.RetrieveResu
return _c return _c
} }
func (_c *MockSegment_RetrieveByOffsets_Call) RunAndReturn(run func(context.Context, *RetrievePlan, []int64) (*segcorepb.RetrieveResults, error)) *MockSegment_RetrieveByOffsets_Call { func (_c *MockSegment_RetrieveByOffsets_Call) RunAndReturn(run func(context.Context, *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error)) *MockSegment_RetrieveByOffsets_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }
@ -1522,27 +1523,27 @@ func (_c *MockSegment_RowNum_Call) RunAndReturn(run func() int64) *MockSegment_R
} }
// Search provides a mock function with given fields: ctx, searchReq // Search provides a mock function with given fields: ctx, searchReq
func (_m *MockSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { func (_m *MockSegment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) {
ret := _m.Called(ctx, searchReq) ret := _m.Called(ctx, searchReq)
if len(ret) == 0 { if len(ret) == 0 {
panic("no return value specified for Search") panic("no return value specified for Search")
} }
var r0 *SearchResult var r0 *segcore.SearchResult
var r1 error var r1 error
if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) (*SearchResult, error)); ok { if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)); ok {
return rf(ctx, searchReq) return rf(ctx, searchReq)
} }
if rf, ok := ret.Get(0).(func(context.Context, *SearchRequest) *SearchResult); ok { if rf, ok := ret.Get(0).(func(context.Context, *segcore.SearchRequest) *segcore.SearchResult); ok {
r0 = rf(ctx, searchReq) r0 = rf(ctx, searchReq)
} else { } else {
if ret.Get(0) != nil { if ret.Get(0) != nil {
r0 = ret.Get(0).(*SearchResult) r0 = ret.Get(0).(*segcore.SearchResult)
} }
} }
if rf, ok := ret.Get(1).(func(context.Context, *SearchRequest) error); ok { if rf, ok := ret.Get(1).(func(context.Context, *segcore.SearchRequest) error); ok {
r1 = rf(ctx, searchReq) r1 = rf(ctx, searchReq)
} else { } else {
r1 = ret.Error(1) r1 = ret.Error(1)
@ -1558,24 +1559,24 @@ type MockSegment_Search_Call struct {
// Search is a helper method to define mock.On call // Search is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - searchReq *SearchRequest // - searchReq *segcore.SearchRequest
func (_e *MockSegment_Expecter) Search(ctx interface{}, searchReq interface{}) *MockSegment_Search_Call { func (_e *MockSegment_Expecter) Search(ctx interface{}, searchReq interface{}) *MockSegment_Search_Call {
return &MockSegment_Search_Call{Call: _e.mock.On("Search", ctx, searchReq)} return &MockSegment_Search_Call{Call: _e.mock.On("Search", ctx, searchReq)}
} }
func (_c *MockSegment_Search_Call) Run(run func(ctx context.Context, searchReq *SearchRequest)) *MockSegment_Search_Call { func (_c *MockSegment_Search_Call) Run(run func(ctx context.Context, searchReq *segcore.SearchRequest)) *MockSegment_Search_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*SearchRequest)) run(args[0].(context.Context), args[1].(*segcore.SearchRequest))
}) })
return _c return _c
} }
func (_c *MockSegment_Search_Call) Return(_a0 *SearchResult, _a1 error) *MockSegment_Search_Call { func (_c *MockSegment_Search_Call) Return(_a0 *segcore.SearchResult, _a1 error) *MockSegment_Search_Call {
_c.Call.Return(_a0, _a1) _c.Call.Return(_a0, _a1)
return _c return _c
} }
func (_c *MockSegment_Search_Call) RunAndReturn(run func(context.Context, *SearchRequest) (*SearchResult, error)) *MockSegment_Search_Call { func (_c *MockSegment_Search_Call) RunAndReturn(run func(context.Context, *segcore.SearchRequest) (*segcore.SearchResult, error)) *MockSegment_Search_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
@ -26,7 +27,7 @@ func CreateInternalReducer(req *querypb.QueryRequest, schema *schemapb.Collectio
} }
type segCoreReducer interface { type segCoreReducer interface {
Reduce(context.Context, []*segcorepb.RetrieveResults, []Segment, *RetrievePlan) (*segcorepb.RetrieveResults, error) Reduce(context.Context, []*segcorepb.RetrieveResults, []Segment, *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error)
} }
func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema, manager *Manager) segCoreReducer { func CreateSegCoreReducer(req *querypb.QueryRequest, schema *schemapb.CollectionSchema, manager *Manager) segCoreReducer {

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/segcore"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
@ -413,7 +414,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
return nil, err return nil, err
} }
validRetrieveResults = append(validRetrieveResults, tr) validRetrieveResults = append(validRetrieveResults, tr)
if plan.ignoreNonPk { if plan.IsIgnoreNonPk() {
validSegments = append(validSegments, segments[i]) validSegments = append(validSegments, segments[i])
} }
loopEnd += size loopEnd += size
@ -493,7 +494,7 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
log.Debug("skip duplicated query result while reducing segcore.RetrieveResults", zap.Int64("dupCount", skipDupCnt)) log.Debug("skip duplicated query result while reducing segcore.RetrieveResults", zap.Int64("dupCount", skipDupCnt))
} }
if !plan.ignoreNonPk { if !plan.IsIgnoreNonPk() {
// target entry already retrieved, don't do this after AppendPKs for better performance. Save the cost everytime // target entry already retrieved, don't do this after AppendPKs for better performance. Save the cost everytime
// judge the `!plan.ignoreNonPk` condition. // judge the `!plan.ignoreNonPk` condition.
_, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData") _, span2 := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "MergeSegcoreResults-AppendFieldData")
@ -524,7 +525,10 @@ func MergeSegcoreRetrieveResults(ctx context.Context, retrieveResults []*segcore
var r *segcorepb.RetrieveResults var r *segcorepb.RetrieveResults
var err error var err error
if err := doOnSegment(ctx, manager, validSegments[idx], func(ctx context.Context, segment Segment) error { if err := doOnSegment(ctx, manager, validSegments[idx], func(ctx context.Context, segment Segment) error {
r, err = segment.RetrieveByOffsets(ctx, plan, theOffsets) r, err = segment.RetrieveByOffsets(ctx, &segcore.RetrievePlanWithOffsets{
RetrievePlan: plan,
Offsets: theOffsets,
})
return err return err
}); err != nil { }); err != nil {
return nil, err return nil, err

View File

@ -27,9 +27,11 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/metric"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
@ -50,7 +52,7 @@ type ResultSuite struct {
} }
func MergeSegcoreRetrieveResultsV1(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam) (*segcorepb.RetrieveResults, error) { func MergeSegcoreRetrieveResultsV1(ctx context.Context, retrieveResults []*segcorepb.RetrieveResults, param *mergeParam) (*segcorepb.RetrieveResults, error) {
plan := &RetrievePlan{ignoreNonPk: false} plan := &segcore.RetrievePlan{}
return MergeSegcoreRetrieveResults(ctx, retrieveResults, param, nil, plan, nil) return MergeSegcoreRetrieveResults(ctx, retrieveResults, param, nil, plan, nil)
} }
@ -66,14 +68,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
var fieldDataArray1 []*schemapb.FieldData var fieldDataArray1 []*schemapb.FieldData
fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1))
fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
var fieldDataArray2 []*schemapb.FieldData var fieldDataArray2 []*schemapb.FieldData
fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1))
fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
suite.Run("test skip dupPK 2", func() { suite.Run("test skip dupPK 2", func() {
result1 := &segcorepb.RetrieveResults{ result1 := &segcorepb.RetrieveResults{
@ -114,14 +116,14 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
suite.Run("test_duppk_multipke_segment", func() { suite.Run("test_duppk_multipke_segment", func() {
var fieldsData1 []*schemapb.FieldData var fieldsData1 []*schemapb.FieldData
fieldsData1 = append(fieldsData1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) fieldsData1 = append(fieldsData1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1))
fieldsData1 = append(fieldsData1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1, 1}, 1)) fieldsData1 = append(fieldsData1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1, 1}, 1))
fieldsData1 = append(fieldsData1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) fieldsData1 = append(fieldsData1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
var fieldsData2 []*schemapb.FieldData var fieldsData2 []*schemapb.FieldData
fieldsData2 = append(fieldsData2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2500}, 1)) fieldsData2 = append(fieldsData2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2500}, 1))
fieldsData2 = append(fieldsData2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1}, 1)) fieldsData2 = append(fieldsData2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, []int64{1}, 1))
fieldsData2 = append(fieldsData2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:8], Dim)) fieldsData2 = append(fieldsData2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:8], Dim))
result1 := &segcorepb.RetrieveResults{ result1 := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{ Ids: &schemapb.IDs{
@ -254,7 +256,7 @@ func (suite *ResultSuite) TestResult_MergeSegcoreRetrieveResults() {
ids[i] = int64(i) ids[i] = int64(i)
offsets[i] = int64(i) offsets[i] = int64(i)
} }
fieldData := genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1) fieldData := mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1)
result := &segcorepb.RetrieveResults{ result := &segcorepb.RetrieveResults{
Ids: &schemapb.IDs{ Ids: &schemapb.IDs{
@ -333,14 +335,14 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0} FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
var fieldDataArray1 []*schemapb.FieldData var fieldDataArray1 []*schemapb.FieldData
fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1)) fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000}, 1))
fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
var fieldDataArray2 []*schemapb.FieldData var fieldDataArray2 []*schemapb.FieldData
fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1)) fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000}, 1))
fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1)) fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim)) fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
suite.Run("test skip dupPK 2", func() { suite.Run("test skip dupPK 2", func() {
result1 := &internalpb.RetrieveResults{ result1 := &internalpb.RetrieveResults{
@ -395,9 +397,9 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
}, },
}, },
FieldsData: []*schemapb.FieldData{ FieldsData: []*schemapb.FieldData{
genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64,
[]int64{1, 2}, 1), []int64{1, 2}, 1),
genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64,
[]int64{3, 4}, 1), []int64{3, 4}, 1),
}, },
} }
@ -410,9 +412,9 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
}, },
}, },
FieldsData: []*schemapb.FieldData{ FieldsData: []*schemapb.FieldData{
genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64,
[]int64{5, 6}, 1), []int64{5, 6}, 1),
genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64,
[]int64{7, 8}, 1), []int64{7, 8}, 1),
}, },
} }
@ -493,7 +495,7 @@ func (suite *ResultSuite) TestResult_MergeInternalRetrieveResults() {
ids[i] = int64(i) ids[i] = int64(i)
offsets[i] = int64(i) offsets[i] = int64(i)
} }
fieldData := genFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1) fieldData := mock_segcore.GenFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1)
result := &internalpb.RetrieveResults{ result := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{ Ids: &schemapb.IDs{
@ -572,17 +574,17 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0} FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0}
var fieldDataArray1 []*schemapb.FieldData var fieldDataArray1 []*schemapb.FieldData
fieldDataArray1 = append(fieldDataArray1, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000, 3000}, 1)) fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{1000, 2000, 3000}, 1))
fieldDataArray1 = append(fieldDataArray1, genFieldData(Int64FieldName, Int64FieldID, fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID,
schemapb.DataType_Int64, Int64Array[0:3], 1)) schemapb.DataType_Int64, Int64Array[0:3], 1))
fieldDataArray1 = append(fieldDataArray1, genFieldData(FloatVectorFieldName, FloatVectorFieldID, fieldDataArray1 = append(fieldDataArray1, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID,
schemapb.DataType_FloatVector, FloatVector[0:12], Dim)) schemapb.DataType_FloatVector, FloatVector[0:12], Dim))
var fieldDataArray2 []*schemapb.FieldData var fieldDataArray2 []*schemapb.FieldData
fieldDataArray2 = append(fieldDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000, 4000}, 1)) fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000, 3000, 4000}, 1))
fieldDataArray2 = append(fieldDataArray2, genFieldData(Int64FieldName, Int64FieldID, fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID,
schemapb.DataType_Int64, Int64Array[0:3], 1)) schemapb.DataType_Int64, Int64Array[0:3], 1))
fieldDataArray2 = append(fieldDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, fieldDataArray2 = append(fieldDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID,
schemapb.DataType_FloatVector, FloatVector[0:12], Dim)) schemapb.DataType_FloatVector, FloatVector[0:12], Dim))
suite.Run("test stop seg core merge for best", func() { suite.Run("test stop seg core merge for best", func() {
@ -712,10 +714,10 @@ func (suite *ResultSuite) TestResult_MergeStopForBestResult() {
FieldsData: fieldDataArray1, FieldsData: fieldDataArray1,
} }
var drainDataArray2 []*schemapb.FieldData var drainDataArray2 []*schemapb.FieldData
drainDataArray2 = append(drainDataArray2, genFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000}, 1)) drainDataArray2 = append(drainDataArray2, mock_segcore.GenFieldData(common.TimeStampFieldName, common.TimeStampField, schemapb.DataType_Int64, []int64{2000}, 1))
drainDataArray2 = append(drainDataArray2, genFieldData(Int64FieldName, Int64FieldID, drainDataArray2 = append(drainDataArray2, mock_segcore.GenFieldData(Int64FieldName, Int64FieldID,
schemapb.DataType_Int64, Int64Array[0:1], 1)) schemapb.DataType_Int64, Int64Array[0:1], 1))
drainDataArray2 = append(drainDataArray2, genFieldData(FloatVectorFieldName, FloatVectorFieldID, drainDataArray2 = append(drainDataArray2, mock_segcore.GenFieldData(FloatVectorFieldName, FloatVectorFieldID,
schemapb.DataType_FloatVector, FloatVector[0:4], Dim)) schemapb.DataType_FloatVector, FloatVector[0:4], Dim))
result2 := &internalpb.RetrieveResults{ result2 := &internalpb.RetrieveResults{
Ids: &schemapb.IDs{ Ids: &schemapb.IDs{
@ -878,28 +880,28 @@ func (suite *ResultSuite) TestSort() {
}, },
Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6}, Offset: []int64{5, 4, 3, 2, 9, 8, 7, 6},
FieldsData: []*schemapb.FieldData{ FieldsData: []*schemapb.FieldData{
genFieldData("int64 field", 100, schemapb.DataType_Int64, mock_segcore.GenFieldData("int64 field", 100, schemapb.DataType_Int64,
[]int64{5, 4, 3, 2, 9, 8, 7, 6}, 1), []int64{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("double field", 101, schemapb.DataType_Double, mock_segcore.GenFieldData("double field", 101, schemapb.DataType_Double,
[]float64{5, 4, 3, 2, 9, 8, 7, 6}, 1), []float64{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("string field", 102, schemapb.DataType_VarChar, mock_segcore.GenFieldData("string field", 102, schemapb.DataType_VarChar,
[]string{"5", "4", "3", "2", "9", "8", "7", "6"}, 1), []string{"5", "4", "3", "2", "9", "8", "7", "6"}, 1),
genFieldData("bool field", 103, schemapb.DataType_Bool, mock_segcore.GenFieldData("bool field", 103, schemapb.DataType_Bool,
[]bool{false, true, false, true, false, true, false, true}, 1), []bool{false, true, false, true, false, true, false, true}, 1),
genFieldData("float field", 104, schemapb.DataType_Float, mock_segcore.GenFieldData("float field", 104, schemapb.DataType_Float,
[]float32{5, 4, 3, 2, 9, 8, 7, 6}, 1), []float32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("int field", 105, schemapb.DataType_Int32, mock_segcore.GenFieldData("int field", 105, schemapb.DataType_Int32,
[]int32{5, 4, 3, 2, 9, 8, 7, 6}, 1), []int32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("float vector field", 106, schemapb.DataType_FloatVector, mock_segcore.GenFieldData("float vector field", 106, schemapb.DataType_FloatVector,
[]float32{5, 4, 3, 2, 9, 8, 7, 6}, 1), []float32{5, 4, 3, 2, 9, 8, 7, 6}, 1),
genFieldData("binary vector field", 107, schemapb.DataType_BinaryVector, mock_segcore.GenFieldData("binary vector field", 107, schemapb.DataType_BinaryVector,
[]byte{5, 4, 3, 2, 9, 8, 7, 6}, 8), []byte{5, 4, 3, 2, 9, 8, 7, 6}, 8),
genFieldData("json field", 108, schemapb.DataType_JSON, mock_segcore.GenFieldData("json field", 108, schemapb.DataType_JSON,
[][]byte{ [][]byte{
[]byte("{\"5\": 5}"), []byte("{\"4\": 4}"), []byte("{\"3\": 3}"), []byte("{\"2\": 2}"), []byte("{\"5\": 5}"), []byte("{\"4\": 4}"), []byte("{\"3\": 3}"), []byte("{\"2\": 2}"),
[]byte("{\"9\": 9}"), []byte("{\"8\": 8}"), []byte("{\"7\": 7}"), []byte("{\"6\": 6}"), []byte("{\"9\": 9}"), []byte("{\"8\": 8}"), []byte("{\"7\": 7}"), []byte("{\"6\": 6}"),
}, 1), }, 1),
genFieldData("json field", 108, schemapb.DataType_Array, mock_segcore.GenFieldData("json field", 108, schemapb.DataType_Array,
[]*schemapb.ScalarField{ []*schemapb.ScalarField{
{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{5, 6, 7}}}}, {Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{5, 6, 7}}}},
{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{4, 5, 6}}}}, {Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{4, 5, 6}}}},

View File

@ -54,7 +54,7 @@ func retrieveOnSegments(ctx context.Context, mgr *Manager, segments []Segment, s
} }
return false return false
}() }()
plan.ignoreNonPk = !anySegIsLazyLoad && len(segments) > 1 && req.GetReq().GetLimit() != typeutil.Unlimited && plan.ShouldIgnoreNonPk() plan.SetIgnoreNonPk(!anySegIsLazyLoad && len(segments) > 1 && req.GetReq().GetLimit() != typeutil.Unlimited && plan.ShouldIgnoreNonPk())
label := metrics.SealedSegmentLabel label := metrics.SealedSegmentLabel
if segType == commonpb.SegmentState_Growing { if segType == commonpb.SegmentState_Growing {

View File

@ -25,11 +25,13 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/internal/util/initcore"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
@ -71,8 +73,8 @@ func (suite *RetrieveSuite) SetupTest() {
suite.segmentID = 1 suite.segmentID = 1
suite.manager = NewManager() suite.manager = NewManager()
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
indexMeta := GenTestIndexMeta(suite.collectionID, schema) indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema)
suite.manager.Collection.PutOrRef(suite.collectionID, suite.manager.Collection.PutOrRef(suite.collectionID,
schema, schema,
indexMeta, indexMeta,
@ -99,7 +101,7 @@ func (suite *RetrieveSuite) SetupTest() {
) )
suite.Require().NoError(err) suite.Require().NoError(err)
binlogs, _, err := SaveBinLog(ctx, binlogs, _, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -127,7 +129,7 @@ func (suite *RetrieveSuite) SetupTest() {
) )
suite.Require().NoError(err) suite.Require().NoError(err)
insertMsg, err := genInsertMsg(suite.collection, suite.partitionID, suite.growing.ID(), msgLength) insertMsg, err := mock_segcore.GenInsertMsg(suite.collection.GetCCollection(), suite.partitionID, suite.growing.ID(), msgLength)
suite.Require().NoError(err) suite.Require().NoError(err)
insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg) insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg)
suite.Require().NoError(err) suite.Require().NoError(err)
@ -147,7 +149,7 @@ func (suite *RetrieveSuite) TearDownTest() {
} }
func (suite *RetrieveSuite) TestRetrieveSealed() { func (suite *RetrieveSuite) TestRetrieveSealed() {
plan, err := genSimpleRetrievePlan(suite.collection) plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection())
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -164,13 +166,16 @@ func (suite *RetrieveSuite) TestRetrieveSealed() {
suite.Len(res[0].Result.Offset, 3) suite.Len(res[0].Result.Offset, 3)
suite.manager.Segment.Unpin(segments) suite.manager.Segment.Unpin(segments)
resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), plan, []int64{0, 1}) resultByOffsets, err := suite.sealed.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
RetrievePlan: plan,
Offsets: []int64{0, 1},
})
suite.NoError(err) suite.NoError(err)
suite.Len(resultByOffsets.Offset, 0) suite.Len(resultByOffsets.Offset, 0)
} }
func (suite *RetrieveSuite) TestRetrieveGrowing() { func (suite *RetrieveSuite) TestRetrieveGrowing() {
plan, err := genSimpleRetrievePlan(suite.collection) plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection())
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -187,13 +192,16 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
suite.Len(res[0].Result.Offset, 3) suite.Len(res[0].Result.Offset, 3)
suite.manager.Segment.Unpin(segments) suite.manager.Segment.Unpin(segments)
resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), plan, []int64{0, 1}) resultByOffsets, err := suite.growing.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
RetrievePlan: plan,
Offsets: []int64{0, 1},
})
suite.NoError(err) suite.NoError(err)
suite.Len(resultByOffsets.Offset, 0) suite.Len(resultByOffsets.Offset, 0)
} }
func (suite *RetrieveSuite) TestRetrieveStreamSealed() { func (suite *RetrieveSuite) TestRetrieveStreamSealed() {
plan, err := genSimpleRetrievePlan(suite.collection) plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection())
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -237,7 +245,7 @@ func (suite *RetrieveSuite) TestRetrieveStreamSealed() {
} }
func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
plan, err := genSimpleRetrievePlan(suite.collection) plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection())
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -256,7 +264,7 @@ func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
} }
func (suite *RetrieveSuite) TestRetrieveNilSegment() { func (suite *RetrieveSuite) TestRetrieveNilSegment() {
plan, err := genSimpleRetrievePlan(suite.collection) plan, err := mock_segcore.GenSimpleRetrievePlan(suite.collection.GetCCollection())
suite.NoError(err) suite.NoError(err)
suite.sealed.Release(context.Background()) suite.sealed.Release(context.Background())

View File

@ -55,7 +55,7 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel, searchLabel).Observe(float64(elapsed)) metrics.SearchLabel, searchLabel).Observe(float64(elapsed))
metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel, searchLabel).Observe(float64(elapsed) / float64(searchReq.getNumOfQuery())) metrics.SearchLabel, searchLabel).Observe(float64(elapsed) / float64(searchReq.GetNumOfQuery()))
return nil return nil
} }
@ -64,7 +64,7 @@ func searchSegments(ctx context.Context, mgr *Manager, segments []Segment, segTy
segmentsWithoutIndex := make([]int64, 0) segmentsWithoutIndex := make([]int64, 0)
for _, segment := range segments { for _, segment := range segments {
seg := segment seg := segment
if !seg.ExistIndex(searchReq.searchFieldID) { if !seg.ExistIndex(searchReq.SearchFieldID()) {
segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID()) segmentsWithoutIndex = append(segmentsWithoutIndex, seg.ID())
} }
errGroup.Go(func() error { errGroup.Go(func() error {
@ -148,7 +148,7 @@ func searchSegmentsStreamly(ctx context.Context,
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel, searchLabel).Observe(float64(searchDuration)) metrics.SearchLabel, searchLabel).Observe(float64(searchDuration))
metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSegmentSearchLatencyPerVector.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.SearchLabel, searchLabel).Observe(float64(searchDuration) / float64(searchReq.getNumOfQuery())) metrics.SearchLabel, searchLabel).Observe(float64(searchDuration) / float64(searchReq.GetNumOfQuery()))
return nil return nil
} }

View File

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/util/reduce" "github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
) )
@ -24,8 +25,8 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
ids := []int64{1, 2, 3, 4} ids := []int64{1, 2, 3, 4}
scores := []float32{-1.0, -2.0, -3.0, -4.0} scores := []float32{-1.0, -2.0, -3.0, -4.0}
topks := []int64{int64(len(ids))} topks := []int64{int64(len(ids))}
data1 := genSearchResultData(nq, topk, ids, scores, topks) data1 := mock_segcore.GenSearchResultData(nq, topk, ids, scores, topks)
data2 := genSearchResultData(nq, topk, ids, scores, topks) data2 := mock_segcore.GenSearchResultData(nq, topk, ids, scores, topks)
dataArray := make([]*schemapb.SearchResultData, 0) dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1) dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2) dataArray = append(dataArray, data2)
@ -43,8 +44,8 @@ func (suite *SearchReduceSuite) TestResult_ReduceSearchResultData() {
ids2 := []int64{5, 1, 3, 4} ids2 := []int64{5, 1, 3, 4}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0} scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
topks2 := []int64{int64(len(ids2))} topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2)
dataArray := make([]*schemapb.SearchResultData, 0) dataArray := make([]*schemapb.SearchResultData, 0)
dataArray = append(dataArray, data1) dataArray = append(dataArray, data1)
dataArray = append(dataArray, data2) dataArray = append(dataArray, data2)
@ -68,8 +69,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
ids2 := []int64{5, 1, 3, 4} ids2 := []int64{5, 1, 3, 4}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0} scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
topks2 := []int64{int64(len(ids2))} topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2)
data1.GroupByFieldValue = &schemapb.FieldData{ data1.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Int8, Type: schemapb.DataType_Int8,
Field: &schemapb.FieldData_Scalars{ Field: &schemapb.FieldData_Scalars{
@ -112,8 +113,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
ids2 := []int64{3, 4} ids2 := []int64{3, 4}
scores2 := []float32{-1.0, -1.0} scores2 := []float32{-1.0, -1.0}
topks2 := []int64{int64(len(ids2))} topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2)
data1.GroupByFieldValue = &schemapb.FieldData{ data1.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_Bool, Type: schemapb.DataType_Bool,
Field: &schemapb.FieldData_Scalars{ Field: &schemapb.FieldData_Scalars{
@ -156,8 +157,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
ids2 := []int64{5, 1, 3, 4} ids2 := []int64{5, 1, 3, 4}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0} scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
topks2 := []int64{int64(len(ids2))} topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2)
data1.GroupByFieldValue = &schemapb.FieldData{ data1.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_VarChar, Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{ Field: &schemapb.FieldData_Scalars{
@ -200,8 +201,8 @@ func (suite *SearchReduceSuite) TestResult_SearchGroupByResult() {
ids2 := []int64{4, 5, 6, 7} ids2 := []int64{4, 5, 6, 7}
scores2 := []float32{-1.0, -1.0, -3.0, -4.0} scores2 := []float32{-1.0, -1.0, -3.0, -4.0}
topks2 := []int64{int64(len(ids2))} topks2 := []int64{int64(len(ids2))}
data1 := genSearchResultData(nq, topk, ids1, scores1, topks1) data1 := mock_segcore.GenSearchResultData(nq, topk, ids1, scores1, topks1)
data2 := genSearchResultData(nq, topk, ids2, scores2, topks2) data2 := mock_segcore.GenSearchResultData(nq, topk, ids2, scores2, topks2)
data1.GroupByFieldValue = &schemapb.FieldData{ data1.GroupByFieldValue = &schemapb.FieldData{
Type: schemapb.DataType_VarChar, Type: schemapb.DataType_VarChar,
Field: &schemapb.FieldData_Scalars{ Field: &schemapb.FieldData_Scalars{

View File

@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
storage "github.com/milvus-io/milvus/internal/storage" storage "github.com/milvus-io/milvus/internal/storage"
@ -62,8 +63,8 @@ func (suite *SearchSuite) SetupTest() {
suite.segmentID = 1 suite.segmentID = 1
suite.manager = NewManager() suite.manager = NewManager()
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
indexMeta := GenTestIndexMeta(suite.collectionID, schema) indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema)
suite.manager.Collection.PutOrRef(suite.collectionID, suite.manager.Collection.PutOrRef(suite.collectionID,
schema, schema,
indexMeta, indexMeta,
@ -90,7 +91,7 @@ func (suite *SearchSuite) SetupTest() {
) )
suite.Require().NoError(err) suite.Require().NoError(err)
binlogs, _, err := SaveBinLog(ctx, binlogs, _, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -118,7 +119,7 @@ func (suite *SearchSuite) SetupTest() {
) )
suite.Require().NoError(err) suite.Require().NoError(err)
insertMsg, err := genInsertMsg(suite.collection, suite.partitionID, suite.growing.ID(), msgLength) insertMsg, err := mock_segcore.GenInsertMsg(suite.collection.GetCCollection(), suite.partitionID, suite.growing.ID(), msgLength)
suite.Require().NoError(err) suite.Require().NoError(err)
insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg) insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg)
suite.Require().NoError(err) suite.Require().NoError(err)
@ -141,7 +142,7 @@ func (suite *SearchSuite) TestSearchSealed() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.sealed.ID()}, IndexFaissIDMap, nq) searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.sealed.ID()}, mock_segcore.IndexFaissIDMap, nq)
suite.NoError(err) suite.NoError(err)
_, segments, err := SearchHistorical(ctx, suite.manager, searchReq, suite.collectionID, nil, []int64{suite.sealed.ID()}) _, segments, err := SearchHistorical(ctx, suite.manager, searchReq, suite.collectionID, nil, []int64{suite.sealed.ID()})
@ -150,7 +151,7 @@ func (suite *SearchSuite) TestSearchSealed() {
} }
func (suite *SearchSuite) TestSearchGrowing() { func (suite *SearchSuite) TestSearchGrowing() {
searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.growing.ID()}, IndexFaissIDMap, 1) searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection.GetCCollection(), []int64{suite.growing.ID()}, mock_segcore.IndexFaissIDMap, 1)
suite.NoError(err) suite.NoError(err)
res, segments, err := SearchStreaming(context.TODO(), suite.manager, searchReq, res, segments, err := SearchStreaming(context.TODO(), suite.manager, searchReq,

View File

@ -0,0 +1,21 @@
package segments
import "github.com/milvus-io/milvus/internal/util/segcore"
type (
SearchRequest = segcore.SearchRequest
SearchResult = segcore.SearchResult
SearchPlan = segcore.SearchPlan
RetrievePlan = segcore.RetrievePlan
)
func DeleteSearchResults(results []*SearchResult) {
if len(results) == 0 {
return
}
for _, result := range results {
if result != nil {
result.Release()
}
}
}

View File

@ -29,7 +29,6 @@ import "C"
import ( import (
"context" "context"
"fmt" "fmt"
"runtime"
"strings" "strings"
"time" "time"
"unsafe" "unsafe"
@ -52,8 +51,8 @@ import (
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/querynodev2/segments/state" "github.com/milvus-io/milvus/internal/querynodev2/segments/state"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/cgo"
"github.com/milvus-io/milvus/internal/util/indexparamcheck" "github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/internal/util/vecindexmgr" "github.com/milvus-io/milvus/internal/util/vecindexmgr"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
@ -268,7 +267,9 @@ var _ Segment = (*LocalSegment)(nil)
type LocalSegment struct { type LocalSegment struct {
baseSegment baseSegment
ptrLock *state.LoadStateLock ptrLock *state.LoadStateLock
ptr C.CSegmentInterface ptr C.CSegmentInterface // TODO: Remove in future, after move load index into segcore package.
// always keep same with csegment.RawPtr(), for eaiser to access,
csegment segcore.CSegment
// cached results, to avoid too many CGO calls // cached results, to avoid too many CGO calls
memSize *atomic.Int64 memSize *atomic.Int64
@ -300,39 +301,17 @@ func NewSegment(ctx context.Context,
return nil, err return nil, err
} }
multipleChunkEnable := paramtable.Get().QueryNodeCfg.MultipleChunkedEnable.GetAsBool()
var cSegType C.SegmentType
var locker *state.LoadStateLock var locker *state.LoadStateLock
switch segmentType { switch segmentType {
case SegmentTypeSealed: case SegmentTypeSealed:
if multipleChunkEnable {
cSegType = C.ChunkedSealed
} else {
cSegType = C.Sealed
}
locker = state.NewLoadStateLock(state.LoadStateOnlyMeta) locker = state.NewLoadStateLock(state.LoadStateOnlyMeta)
case SegmentTypeGrowing: case SegmentTypeGrowing:
locker = state.NewLoadStateLock(state.LoadStateDataLoaded) locker = state.NewLoadStateLock(state.LoadStateDataLoaded)
cSegType = C.Growing
default: default:
return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID()) return nil, fmt.Errorf("illegal segment type %d when create segment %d", segmentType, loadInfo.GetSegmentID())
} }
var newPtr C.CSegmentInterface logger := log.With(
_, err = GetDynamicPool().Submit(func() (any, error) {
status := C.NewSegment(collection.collectionPtr, cSegType, C.int64_t(loadInfo.GetSegmentID()), &newPtr, C.bool(loadInfo.GetIsSorted()))
err := HandleCStatus(ctx, &status, "NewSegmentFailed",
zap.Int64("collectionID", loadInfo.GetCollectionID()),
zap.Int64("partitionID", loadInfo.GetPartitionID()),
zap.Int64("segmentID", loadInfo.GetSegmentID()),
zap.String("segmentType", segmentType.String()))
return nil, err
}).Await()
if err != nil {
return nil, err
}
log.Info("create segment",
zap.Int64("collectionID", loadInfo.GetCollectionID()), zap.Int64("collectionID", loadInfo.GetCollectionID()),
zap.Int64("partitionID", loadInfo.GetPartitionID()), zap.Int64("partitionID", loadInfo.GetPartitionID()),
zap.Int64("segmentID", loadInfo.GetSegmentID()), zap.Int64("segmentID", loadInfo.GetSegmentID()),
@ -340,10 +319,28 @@ func NewSegment(ctx context.Context,
zap.String("level", loadInfo.GetLevel().String()), zap.String("level", loadInfo.GetLevel().String()),
) )
var csegment segcore.CSegment
if _, err := GetDynamicPool().Submit(func() (any, error) {
var err error
csegment, err = segcore.CreateCSegment(&segcore.CreateCSegmentRequest{
Collection: collection.ccollection,
SegmentID: loadInfo.GetSegmentID(),
SegmentType: segmentType,
IsSorted: loadInfo.GetIsSorted(),
EnableChunked: paramtable.Get().QueryNodeCfg.MultipleChunkedEnable.GetAsBool(),
})
return nil, err
}).Await(); err != nil {
logger.Warn("create segment failed", zap.Error(err))
return nil, err
}
log.Info("create segment done")
segment := &LocalSegment{ segment := &LocalSegment{
baseSegment: base, baseSegment: base,
ptrLock: locker, ptrLock: locker,
ptr: newPtr, ptr: C.CSegmentInterface(csegment.RawPointer()),
csegment: csegment,
lastDeltaTimestamp: atomic.NewUint64(0), lastDeltaTimestamp: atomic.NewUint64(0),
fields: typeutil.NewConcurrentMap[int64, *FieldInfo](), fields: typeutil.NewConcurrentMap[int64, *FieldInfo](),
fieldIndexes: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](), fieldIndexes: typeutil.NewConcurrentMap[int64, *IndexedFieldInfo](),
@ -354,6 +351,7 @@ func NewSegment(ctx context.Context,
} }
if err := segment.initializeSegment(); err != nil { if err := segment.initializeSegment(); err != nil {
csegment.Release()
return nil, err return nil, err
} }
return segment, nil return segment, nil
@ -424,15 +422,12 @@ func (s *LocalSegment) RowNum() int64 {
rowNum := s.rowNum.Load() rowNum := s.rowNum.Load()
if rowNum < 0 { if rowNum < 0 {
var rowCount C.int64_t
GetDynamicPool().Submit(func() (any, error) { GetDynamicPool().Submit(func() (any, error) {
rowCount = C.GetRealCount(s.ptr) rowNum = s.csegment.RowNum()
s.rowNum.Store(int64(rowCount)) s.rowNum.Store(rowNum)
return nil, nil return nil, nil
}).Await() }).Await()
rowNum = int64(rowCount)
} }
return rowNum return rowNum
} }
@ -444,14 +439,11 @@ func (s *LocalSegment) MemSize() int64 {
memSize := s.memSize.Load() memSize := s.memSize.Load()
if memSize < 0 { if memSize < 0 {
var cMemSize C.int64_t
GetDynamicPool().Submit(func() (any, error) { GetDynamicPool().Submit(func() (any, error) {
cMemSize = C.GetMemoryUsageInBytes(s.ptr) memSize = s.csegment.MemSize()
s.memSize.Store(int64(cMemSize)) s.memSize.Store(memSize)
return nil, nil return nil, nil
}).Await() }).Await()
memSize = int64(cMemSize)
} }
return memSize return memSize
} }
@ -479,8 +471,7 @@ func (s *LocalSegment) HasRawData(fieldID int64) bool {
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
ret := C.HasRawData(s.ptr, C.int64_t(fieldID)) return s.csegment.HasRawData(fieldID)
return bool(ret)
} }
func (s *LocalSegment) Indexes() []*IndexedFieldInfo { func (s *LocalSegment) Indexes() []*IndexedFieldInfo {
@ -498,192 +489,124 @@ func (s *LocalSegment) ResetIndexesLazyLoad(lazyState bool) {
} }
} }
func (s *LocalSegment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { func (s *LocalSegment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) {
/* log := log.Ctx(ctx).WithLazy(
CStatus
Search(void* plan,
void* placeholder_groups,
uint64_t* timestamps,
int num_groups,
long int* result_ids,
float* result_distances);
*/
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()), zap.Int64("collectionID", s.Collection()),
zap.Int64("segmentID", s.ID()), zap.Int64("segmentID", s.ID()),
zap.String("segmentType", s.segmentType.String()), zap.String("segmentType", s.segmentType.String()),
) )
if !s.ptrLock.RLockIf(state.IsNotReleased) { if !s.ptrLock.RLockIf(state.IsNotReleased) {
// TODO: check if the segment is readable but not released. too many related logic need to be refactor. // TODO: check if the segment is readable but not released. too many related logic need to be refactor.
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
traceCtx := ParseCTraceContext(ctx) hasIndex := s.ExistIndex(searchReq.SearchFieldID())
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(searchReq)
hasIndex := s.ExistIndex(searchReq.searchFieldID)
log = log.With(zap.Bool("withIndex", hasIndex)) log = log.With(zap.Bool("withIndex", hasIndex))
log.Debug("search segment...") log.Debug("search segment...")
tr := timerecord.NewTimeRecorder("cgoSearch") tr := timerecord.NewTimeRecorder("cgoSearch")
result, err := s.csegment.Search(ctx, searchReq)
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncSearch(
traceCtx.ctx,
s.ptr,
searchReq.plan.cSearchPlan,
searchReq.cPlaceholderGroup,
C.uint64_t(searchReq.mvccTimestamp),
))
},
cgo.WithName("search"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil { if err != nil {
log.Warn("Search failed") log.Warn("Search failed")
return nil, err return nil, err
} }
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("search segment done") log.Debug("search segment done")
return &SearchResult{ return result, nil
cSearchResult: (C.CSearchResult)(result),
}, nil
} }
func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { func (s *LocalSegment) retrieve(ctx context.Context, plan *segcore.RetrievePlan, log *zap.Logger) (*segcore.RetrieveResult, error) {
if !s.ptrLock.RLockIf(state.IsNotReleased) { if !s.ptrLock.RLockIf(state.IsNotReleased) {
// TODO: check if the segment is readable but not released. too many related logic need to be refactor. // TODO: check if the segment is readable but not released. too many related logic need to be refactor.
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
log := log.Ctx(ctx).With(
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.segmentType.String()),
)
log.Debug("begin to retrieve") log.Debug("begin to retrieve")
traceCtx := ParseCTraceContext(ctx)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(plan)
maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
tr := timerecord.NewTimeRecorder("cgoRetrieve") tr := timerecord.NewTimeRecorder("cgoRetrieve")
result, err := s.csegment.Retrieve(ctx, plan)
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncRetrieve(
traceCtx.ctx,
s.ptr,
plan.cRetrievePlan,
C.uint64_t(plan.Timestamp),
C.int64_t(maxLimitSize),
C.bool(plan.ignoreNonPk),
))
},
cgo.WithName("retrieve"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil { if err != nil {
log.Warn("Retrieve failed") log.Warn("Retrieve failed")
return nil, err return nil, err
} }
defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
return result, nil
}
func (s *LocalSegment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) {
log := log.Ctx(ctx).WithLazy(
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.Int64("msgID", plan.MsgID()),
zap.String("segmentType", s.segmentType.String()),
)
result, err := s.retrieve(ctx, plan, log)
if err != nil {
return nil, err
}
defer result.Release()
_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "partial-segcore-results-deserialization") _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "partial-segcore-results-deserialization")
defer span.End() defer span.End()
retrieveResult := new(segcorepb.RetrieveResults) retrieveResult, err := result.GetResult()
if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil { if err != nil {
log.Warn("unmarshal retrieve result failed", zap.Error(err)) log.Warn("unmarshal retrieve result failed", zap.Error(err))
return nil, err return nil, err
} }
log.Debug("retrieve segment done", zap.Int("resultNum", len(retrieveResult.Offset)))
log.Debug("retrieve segment done",
zap.Int("resultNum", len(retrieveResult.Offset)),
)
// Sort was done by the segcore.
// sort.Sort(&byPK{result})
return retrieveResult, nil return retrieveResult, nil
} }
func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { func (s *LocalSegment) retrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets, log *zap.Logger) (*segcore.RetrieveResult, error) {
if len(offsets) == 0 {
return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets")
}
if !s.ptrLock.RLockIf(state.IsNotReleased) { if !s.ptrLock.RLockIf(state.IsNotReleased) {
// TODO: check if the segment is readable but not released. too many related logic need to be refactor. // TODO: check if the segment is readable but not released. too many related logic need to be refactor.
return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released") return nil, merr.WrapErrSegmentNotLoaded(s.ID(), "segment released")
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
fields := []zap.Field{
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.segmentType.String()),
zap.Int("resultNum", len(offsets)),
}
log := log.Ctx(ctx).With(fields...)
log.Debug("begin to retrieve by offsets") log.Debug("begin to retrieve by offsets")
tr := timerecord.NewTimeRecorder("cgoRetrieveByOffsets") tr := timerecord.NewTimeRecorder("cgoRetrieveByOffsets")
traceCtx := ParseCTraceContext(ctx) result, err := s.csegment.RetrieveByOffsets(ctx, plan)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(plan)
defer runtime.KeepAlive(offsets)
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncRetrieveByOffsets(
traceCtx.ctx,
s.ptr,
plan.cRetrievePlan,
(*C.int64_t)(unsafe.Pointer(&offsets[0])),
C.int64_t(len(offsets)),
))
},
cgo.WithName("retrieve-by-offsets"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil { if err != nil {
log.Warn("RetrieveByOffsets failed") log.Warn("RetrieveByOffsets failed")
return nil, err return nil, err
} }
defer C.DeleteRetrieveResult((*C.CRetrieveResult)(result))
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
return result, nil
}
func (s *LocalSegment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error) {
log := log.Ctx(ctx).WithLazy(zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.Int64("msgID", plan.MsgID()),
zap.String("segmentType", s.segmentType.String()),
zap.Int("resultNum", len(plan.Offsets)),
)
result, err := s.retrieveByOffsets(ctx, plan, log)
if err != nil {
return nil, err
}
defer result.Release()
_, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "reduced-segcore-results-deserialization") _, span := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "reduced-segcore-results-deserialization")
defer span.End() defer span.End()
retrieveResult := new(segcorepb.RetrieveResults) retrieveResult, err := result.GetResult()
if err := UnmarshalCProto((*C.CRetrieveResult)(result), retrieveResult); err != nil { if err != nil {
log.Warn("unmarshal retrieve by offsets result failed", zap.Error(err)) log.Warn("unmarshal retrieve by offsets result failed", zap.Error(err))
return nil, err return nil, err
} }
log.Debug("retrieve by segment offsets done", zap.Int("resultNum", len(retrieveResult.Offset)))
log.Debug("retrieve by segment offsets done",
zap.Int("resultNum", len(retrieveResult.Offset)),
)
return retrieveResult, nil return retrieveResult, nil
} }
@ -700,26 +623,6 @@ func (s *LocalSegment) GetFieldDataPath(index *IndexedFieldInfo, offset int64) (
return dataPath, offsetInBinlog return dataPath, offsetInBinlog
} }
// -------------------------------------------------------------------------------------- interfaces for growing segment
func (s *LocalSegment) preInsert(ctx context.Context, numOfRecords int) (int64, error) {
/*
long int
PreInsert(CSegmentInterface c_segment, long int size);
*/
var offset int64
cOffset := (*C.int64_t)(&offset)
var status C.CStatus
GetDynamicPool().Submit(func() (any, error) {
status = C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset)
return nil, nil
}).Await()
if err := HandleCStatus(ctx, &status, "PreInsert failed"); err != nil {
return 0, err
}
return offset, nil
}
func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []typeutil.Timestamp, record *segcorepb.InsertRecord) error { func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []typeutil.Timestamp, record *segcorepb.InsertRecord) error {
if s.Type() != SegmentTypeGrowing { if s.Type() != SegmentTypeGrowing {
return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String()) return fmt.Errorf("unexpected segmentType when segmentInsert, segmentType = %s", s.segmentType.String())
@ -729,24 +632,8 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
offset, err := s.preInsert(ctx, len(rowIDs)) var result *segcore.InsertResult
if err != nil { var err error
return err
}
insertRecordBlob, err := proto.Marshal(record)
if err != nil {
return fmt.Errorf("failed to marshal insert record: %s", err)
}
numOfRow := len(rowIDs)
cOffset := C.int64_t(offset)
cNumOfRows := C.int64_t(numOfRow)
cEntityIDsPtr := (*C.int64_t)(&(rowIDs)[0])
cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0])
var status C.CStatus
GetDynamicPool().Submit(func() (any, error) { GetDynamicPool().Submit(func() (any, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
@ -756,21 +643,19 @@ func (s *LocalSegment) Insert(ctx context.Context, rowIDs []int64, timestamps []
"Sync", "Sync",
).Observe(float64(time.Since(start).Milliseconds())) ).Observe(float64(time.Since(start).Milliseconds()))
}() }()
status = C.Insert(s.ptr,
cOffset, result, err = s.csegment.Insert(ctx, &segcore.InsertRequest{
cNumOfRows, RowIDs: rowIDs,
cEntityIDsPtr, Timestamps: timestamps,
cTimestampsPtr, Record: record,
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])), })
(C.uint64_t)(len(insertRecordBlob)),
)
return nil, nil return nil, nil
}).Await() }).Await()
if err := HandleCStatus(ctx, &status, "Insert failed"); err != nil {
if err != nil {
return err return err
} }
s.insertCount.Add(result.InsertedRows)
s.insertCount.Add(int64(numOfRow))
s.rowNum.Store(-1) s.rowNum.Store(-1)
s.memSize.Store(-1) s.memSize.Store(-1)
return nil return nil
@ -794,20 +679,7 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys storage.PrimaryKe
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
cOffset := C.int64_t(0) // depre var err error
cSize := C.int64_t(primaryKeys.Len())
cTimestampsPtr := (*C.uint64_t)(&(timestamps)[0])
ids, err := storage.ParsePrimaryKeysBatch2IDs(primaryKeys)
if err != nil {
return err
}
dataBlob, err := proto.Marshal(ids)
if err != nil {
return fmt.Errorf("failed to marshal ids: %s", err)
}
var status C.CStatus
GetDynamicPool().Submit(func() (any, error) { GetDynamicPool().Submit(func() (any, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
@ -817,23 +689,19 @@ func (s *LocalSegment) Delete(ctx context.Context, primaryKeys storage.PrimaryKe
"Sync", "Sync",
).Observe(float64(time.Since(start).Milliseconds())) ).Observe(float64(time.Since(start).Milliseconds()))
}() }()
status = C.Delete(s.ptr, _, err = s.csegment.Delete(ctx, &segcore.DeleteRequest{
cOffset, PrimaryKeys: primaryKeys,
cSize, Timestamps: timestamps,
(*C.uint8_t)(unsafe.Pointer(&dataBlob[0])), })
(C.uint64_t)(len(dataBlob)),
cTimestampsPtr,
)
return nil, nil return nil, nil
}).Await() }).Await()
if err := HandleCStatus(ctx, &status, "Delete failed"); err != nil { if err != nil {
return err return err
} }
s.rowNum.Store(-1) s.rowNum.Store(-1)
s.lastDeltaTimestamp.Store(timestamps[len(timestamps)-1]) s.lastDeltaTimestamp.Store(timestamps[len(timestamps)-1])
return nil return nil
} }
@ -854,30 +722,17 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context) error {
zap.Int64("segmentID", s.ID()), zap.Int64("segmentID", s.ID()),
) )
loadFieldDataInfo, err := newLoadFieldDataInfo(ctx) req := &segcore.LoadFieldDataRequest{
defer deleteFieldDataInfo(loadFieldDataInfo) MMapDir: paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue(),
if err != nil { RowCount: rowCount,
return err
} }
for _, field := range fields { for _, field := range fields {
fieldID := field.FieldID req.Fields = append(req.Fields, segcore.LoadFieldDataInfo{
err = loadFieldDataInfo.appendLoadFieldInfo(ctx, fieldID, rowCount) Field: field,
if err != nil { })
return err
}
for _, binlog := range field.Binlogs {
err = loadFieldDataInfo.appendLoadFieldDataPath(ctx, fieldID, binlog)
if err != nil {
return err
}
}
loadFieldDataInfo.appendMMapDirPath(paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue())
} }
var status C.CStatus var err error
GetLoadPool().Submit(func() (any, error) { GetLoadPool().Submit(func() (any, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
@ -887,20 +742,15 @@ func (s *LocalSegment) LoadMultiFieldData(ctx context.Context) error {
"Sync", "Sync",
).Observe(float64(time.Since(start).Milliseconds())) ).Observe(float64(time.Since(start).Milliseconds()))
}() }()
status = C.LoadFieldData(s.ptr, loadFieldDataInfo.cLoadFieldDataInfo) _, err = s.csegment.LoadFieldData(ctx, req)
return nil, nil return nil, nil
}).Await() }).Await()
if err := HandleCStatus(ctx, &status, "LoadMultiFieldData failed", if err != nil {
zap.Int64("collectionID", s.Collection()), log.Warn("LoadMultiFieldData failed", zap.Error(err))
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID())); err != nil {
return err return err
} }
log.Info("load mutil field done", log.Info("load mutil field done", zap.Int64("row count", rowCount), zap.Int64("segmentID", s.ID()))
zap.Int64("row count", rowCount),
zap.Int64("segmentID", s.ID()))
return nil return nil
} }
@ -922,26 +772,6 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
) )
log.Info("start loading field data for field") log.Info("start loading field data for field")
loadFieldDataInfo, err := newLoadFieldDataInfo(ctx)
if err != nil {
return err
}
defer deleteFieldDataInfo(loadFieldDataInfo)
err = loadFieldDataInfo.appendLoadFieldInfo(ctx, fieldID, rowCount)
if err != nil {
return err
}
if field != nil {
for _, binlog := range field.Binlogs {
err = loadFieldDataInfo.appendLoadFieldDataPath(ctx, fieldID, binlog)
if err != nil {
return err
}
}
}
// TODO retrieve_enable should be considered // TODO retrieve_enable should be considered
collection := s.collection collection := s.collection
fieldSchema, err := getFieldSchema(collection.Schema(), fieldID) fieldSchema, err := getFieldSchema(collection.Schema(), fieldID)
@ -949,10 +779,15 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
return err return err
} }
mmapEnabled := isDataMmapEnable(fieldSchema) mmapEnabled := isDataMmapEnable(fieldSchema)
loadFieldDataInfo.appendMMapDirPath(paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue()) req := &segcore.LoadFieldDataRequest{
loadFieldDataInfo.enableMmap(fieldID, mmapEnabled) MMapDir: paramtable.Get().QueryNodeCfg.MmapDirPath.GetValue(),
Fields: []segcore.LoadFieldDataInfo{{
Field: field,
EnableMMap: mmapEnabled,
}},
RowCount: rowCount,
}
var status C.CStatus
GetLoadPool().Submit(func() (any, error) { GetLoadPool().Submit(func() (any, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
@ -962,20 +797,16 @@ func (s *LocalSegment) LoadFieldData(ctx context.Context, fieldID int64, rowCoun
"Sync", "Sync",
).Observe(float64(time.Since(start).Milliseconds())) ).Observe(float64(time.Since(start).Milliseconds()))
}() }()
_, err = s.csegment.LoadFieldData(ctx, req)
log.Info("submitted loadFieldData task to load pool") log.Info("submitted loadFieldData task to load pool")
status = C.LoadFieldData(s.ptr, loadFieldDataInfo.cLoadFieldDataInfo)
return nil, nil return nil, nil
}).Await() }).Await()
if err := HandleCStatus(ctx, &status, "LoadFieldData failed",
zap.Int64("collectionID", s.Collection()), if err != nil {
zap.Int64("partitionID", s.Partition()), log.Warn("LoadFieldData failed", zap.Error(err))
zap.Int64("segmentID", s.ID()),
zap.Int64("fieldID", fieldID)); err != nil {
return err return err
} }
log.Info("load field done") log.Info("load field done")
return nil return nil
} }
@ -985,46 +816,33 @@ func (s *LocalSegment) AddFieldDataInfo(ctx context.Context, rowCount int64, fie
} }
defer s.ptrLock.RUnlock() defer s.ptrLock.RUnlock()
log := log.Ctx(ctx).With( log := log.Ctx(ctx).WithLazy(
zap.Int64("collectionID", s.Collection()), zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()), zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()), zap.Int64("segmentID", s.ID()),
zap.Int64("row count", rowCount), zap.Int64("row count", rowCount),
) )
loadFieldDataInfo, err := newLoadFieldDataInfo(ctx) req := &segcore.AddFieldDataInfoRequest{
if err != nil { Fields: make([]segcore.LoadFieldDataInfo, 0, len(fields)),
return err RowCount: rowCount,
} }
defer deleteFieldDataInfo(loadFieldDataInfo)
for _, field := range fields { for _, field := range fields {
fieldID := field.FieldID req.Fields = append(req.Fields, segcore.LoadFieldDataInfo{
err = loadFieldDataInfo.appendLoadFieldInfo(ctx, fieldID, rowCount) Field: field,
if err != nil { })
return err
}
for _, binlog := range field.Binlogs {
err = loadFieldDataInfo.appendLoadFieldDataPath(ctx, fieldID, binlog)
if err != nil {
return err
}
}
} }
var status C.CStatus var err error
GetLoadPool().Submit(func() (any, error) { GetLoadPool().Submit(func() (any, error) {
status = C.AddFieldDataInfoForSealed(s.ptr, loadFieldDataInfo.cLoadFieldDataInfo) _, err = s.csegment.AddFieldDataInfo(ctx, req)
return nil, nil return nil, nil
}).Await() }).Await()
if err := HandleCStatus(ctx, &status, "AddFieldDataInfo failed",
zap.Int64("collectionID", s.Collection()), if err != nil {
zap.Int64("partitionID", s.Partition()), log.Warn("AddFieldDataInfo failed", zap.Error(err))
zap.Int64("segmentID", s.ID())); err != nil {
return err return err
} }
log.Info("add field data info done") log.Info("add field data info done")
return nil return nil
} }
@ -1456,7 +1274,7 @@ func (s *LocalSegment) Release(ctx context.Context, opts ...releaseOption) {
C.DeleteSegment(ptr) C.DeleteSegment(ptr)
localDiskUsage, err := GetLocalUsedSize(context.Background(), paramtable.Get().LocalStorageCfg.Path.GetValue()) localDiskUsage, err := segcore.GetLocalUsedSize(context.Background(), paramtable.Get().LocalStorageCfg.Path.GetValue())
// ignore error here, shall not block releasing // ignore error here, shall not block releasing
if err == nil { if err == nil {
metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(localDiskUsage) / 1024 / 1024) // in MB metrics.QueryNodeDiskUsedSize.WithLabelValues(fmt.Sprint(paramtable.GetNodeID())).Set(float64(localDiskUsage) / 1024 / 1024) // in MB

View File

@ -24,6 +24,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/metautil"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
) )
@ -92,9 +93,9 @@ type Segment interface {
GetBM25Stats() map[int64]*storage.BM25Stats GetBM25Stats() map[int64]*storage.BM25Stats
// Read operations // Read operations
Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error)
Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error)
RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error)
IsLazyLoad() bool IsLazyLoad() bool
ResetIndexesLazyLoad(lazyState bool) ResetIndexesLazyLoad(lazyState bool)

View File

@ -27,6 +27,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
storage "github.com/milvus-io/milvus/internal/storage" storage "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/typeutil"
@ -131,15 +132,15 @@ func (s *L0Segment) Level() datapb.SegmentLevel {
return datapb.SegmentLevel_L0 return datapb.SegmentLevel_L0
} }
func (s *L0Segment) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) { func (s *L0Segment) Search(ctx context.Context, searchReq *segcore.SearchRequest) (*segcore.SearchResult, error) {
return nil, nil return nil, nil
} }
func (s *L0Segment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segcorepb.RetrieveResults, error) { func (s *L0Segment) Retrieve(ctx context.Context, plan *segcore.RetrievePlan) (*segcorepb.RetrieveResults, error) {
return nil, nil return nil, nil
} }
func (s *L0Segment) RetrieveByOffsets(ctx context.Context, plan *RetrievePlan, offsets []int64) (*segcorepb.RetrieveResults, error) { func (s *L0Segment) RetrieveByOffsets(ctx context.Context, plan *segcore.RetrievePlanWithOffsets) (*segcorepb.RetrieveResults, error) {
return nil, nil return nil, nil
} }

View File

@ -46,6 +46,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/pkoracle" "github.com/milvus-io/milvus/internal/querynodev2/pkoracle"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
@ -452,7 +453,7 @@ func (loader *segmentLoader) requestResource(ctx context.Context, infos ...*quer
memoryUsage := hardware.GetUsedMemoryCount() memoryUsage := hardware.GetUsedMemoryCount()
totalMemory := hardware.GetMemoryCount() totalMemory := hardware.GetMemoryCount()
diskUsage, err := GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue()) diskUsage, err := segcore.GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue())
if err != nil { if err != nil {
return result, errors.Wrap(err, "get local used size failed") return result, errors.Wrap(err, "get local used size failed")
} }
@ -1365,7 +1366,7 @@ func (loader *segmentLoader) checkSegmentSize(ctx context.Context, segmentLoadIn
return 0, 0, errors.New("get memory failed when checkSegmentSize") return 0, 0, errors.New("get memory failed when checkSegmentSize")
} }
localDiskUsage, err := GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue()) localDiskUsage, err := segcore.GetLocalUsedSize(ctx, paramtable.Get().LocalStorageCfg.Path.GetValue())
if err != nil { if err != nil {
return 0, 0, errors.Wrap(err, "get local used size failed") return 0, 0, errors.Wrap(err, "get local used size failed")
} }

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
@ -84,8 +85,8 @@ func (suite *SegmentLoaderSuite) SetupTest() {
initcore.InitRemoteChunkManager(paramtable.Get()) initcore.InitRemoteChunkManager(paramtable.Get())
// Data // Data
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64, false) suite.schema = mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false)
indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema) indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema)
loadMeta := &querypb.LoadMetaInfo{ loadMeta := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID, CollectionID: suite.collectionID,
@ -100,8 +101,8 @@ func (suite *SegmentLoaderSuite) SetupBM25() {
suite.loader = NewLoader(suite.manager, suite.chunkManager) suite.loader = NewLoader(suite.manager, suite.chunkManager)
initcore.InitRemoteChunkManager(paramtable.Get()) initcore.InitRemoteChunkManager(paramtable.Get())
suite.schema = GenTestBM25CollectionSchema("test") suite.schema = mock_segcore.GenTestBM25CollectionSchema("test")
indexMeta := GenTestIndexMeta(suite.collectionID, suite.schema) indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema)
loadMeta := &querypb.LoadMetaInfo{ loadMeta := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID, CollectionID: suite.collectionID,
@ -124,7 +125,7 @@ func (suite *SegmentLoaderSuite) TestLoad() {
msgLength := 4 msgLength := 4
// Load sealed // Load sealed
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -146,7 +147,7 @@ func (suite *SegmentLoaderSuite) TestLoad() {
suite.NoError(err) suite.NoError(err)
// Load growing // Load growing
binlogs, statsLogs, err = SaveBinLog(ctx, binlogs, statsLogs, err = mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID+1, suite.segmentID+1,
@ -174,7 +175,7 @@ func (suite *SegmentLoaderSuite) TestLoadFail() {
msgLength := 4 msgLength := 4
// Load sealed // Load sealed
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -211,7 +212,7 @@ func (suite *SegmentLoaderSuite) TestLoadMultipleSegments() {
// Load sealed // Load sealed
for i := 0; i < suite.segmentNum; i++ { for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i) segmentID := suite.segmentID + int64(i)
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
@ -247,7 +248,7 @@ func (suite *SegmentLoaderSuite) TestLoadMultipleSegments() {
loadInfos = loadInfos[:0] loadInfos = loadInfos[:0]
for i := 0; i < suite.segmentNum; i++ { for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(suite.segmentNum) + int64(i) segmentID := suite.segmentID + int64(suite.segmentNum) + int64(i)
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
@ -287,7 +288,7 @@ func (suite *SegmentLoaderSuite) TestLoadWithIndex() {
// Load sealed // Load sealed
for i := 0; i < suite.segmentNum; i++ { for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i) segmentID := suite.segmentID + int64(i)
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
@ -298,13 +299,13 @@ func (suite *SegmentLoaderSuite) TestLoadWithIndex() {
suite.NoError(err) suite.NoError(err)
vecFields := funcutil.GetVecFieldIDs(suite.schema) vecFields := funcutil.GetVecFieldIDs(suite.schema)
indexInfo, err := GenAndSaveIndex( indexInfo, err := mock_segcore.GenAndSaveIndex(
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
vecFields[0], vecFields[0],
msgLength, msgLength,
IndexFaissIVFFlat, mock_segcore.IndexFaissIVFFlat,
metric.L2, metric.L2,
suite.chunkManager, suite.chunkManager,
) )
@ -338,7 +339,7 @@ func (suite *SegmentLoaderSuite) TestLoadBloomFilter() {
// Load sealed // Load sealed
for i := 0; i < suite.segmentNum; i++ { for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i) segmentID := suite.segmentID + int64(i)
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
@ -379,7 +380,7 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() {
// Load sealed // Load sealed
for i := 0; i < suite.segmentNum; i++ { for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i) segmentID := suite.segmentID + int64(i)
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
@ -390,7 +391,7 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() {
suite.NoError(err) suite.NoError(err)
// Delete PKs 1, 2 // Delete PKs 1, 2
deltaLogs, err := SaveDeltaLog(suite.collectionID, deltaLogs, err := mock_segcore.SaveDeltaLog(suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
suite.chunkManager, suite.chunkManager,
@ -428,13 +429,13 @@ func (suite *SegmentLoaderSuite) TestLoadDeltaLogs() {
func (suite *SegmentLoaderSuite) TestLoadBm25Stats() { func (suite *SegmentLoaderSuite) TestLoadBm25Stats() {
suite.SetupBM25() suite.SetupBM25()
msgLength := 1 msgLength := 1
sparseFieldID := simpleSparseFloatVectorField.id sparseFieldID := mock_segcore.SimpleSparseFloatVectorField.ID
loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum) loadInfos := make([]*querypb.SegmentLoadInfo, 0, suite.segmentNum)
for i := 0; i < suite.segmentNum; i++ { for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i) segmentID := suite.segmentID + int64(i)
bm25logs, err := SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager) bm25logs, err := mock_segcore.SaveBM25Log(suite.collectionID, suite.partitionID, segmentID, sparseFieldID, msgLength, suite.chunkManager)
suite.NoError(err) suite.NoError(err)
loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{ loadInfos = append(loadInfos, &querypb.SegmentLoadInfo{
@ -468,7 +469,7 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
// Load sealed // Load sealed
for i := 0; i < suite.segmentNum; i++ { for i := 0; i < suite.segmentNum; i++ {
segmentID := suite.segmentID + int64(i) segmentID := suite.segmentID + int64(i)
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
@ -479,7 +480,7 @@ func (suite *SegmentLoaderSuite) TestLoadDupDeltaLogs() {
suite.NoError(err) suite.NoError(err)
// Delete PKs 1, 2 // Delete PKs 1, 2
deltaLogs, err := SaveDeltaLog(suite.collectionID, deltaLogs, err := mock_segcore.SaveDeltaLog(suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
suite.chunkManager, suite.chunkManager,
@ -602,7 +603,7 @@ func (suite *SegmentLoaderSuite) TestLoadWithMmap() {
msgLength := 100 msgLength := 100
// Load sealed // Load sealed
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -629,7 +630,7 @@ func (suite *SegmentLoaderSuite) TestPatchEntryNum() {
msgLength := 100 msgLength := 100
segmentID := suite.segmentID segmentID := suite.segmentID
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
@ -640,13 +641,13 @@ func (suite *SegmentLoaderSuite) TestPatchEntryNum() {
suite.NoError(err) suite.NoError(err)
vecFields := funcutil.GetVecFieldIDs(suite.schema) vecFields := funcutil.GetVecFieldIDs(suite.schema)
indexInfo, err := GenAndSaveIndex( indexInfo, err := mock_segcore.GenAndSaveIndex(
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
segmentID, segmentID,
vecFields[0], vecFields[0],
msgLength, msgLength,
IndexFaissIVFFlat, mock_segcore.IndexFaissIVFFlat,
metric.L2, metric.L2,
suite.chunkManager, suite.chunkManager,
) )
@ -690,7 +691,7 @@ func (suite *SegmentLoaderSuite) TestRunOutMemory() {
msgLength := 4 msgLength := 4
// Load sealed // Load sealed
binlogs, statsLogs, err := SaveBinLog(ctx, binlogs, statsLogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -712,7 +713,7 @@ func (suite *SegmentLoaderSuite) TestRunOutMemory() {
suite.Error(err) suite.Error(err)
// Load growing // Load growing
binlogs, statsLogs, err = SaveBinLog(ctx, binlogs, statsLogs, err = mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID+1, suite.segmentID+1,
@ -782,7 +783,7 @@ func (suite *SegmentLoaderDetailSuite) SetupSuite() {
suite.partitionID = rand.Int63() suite.partitionID = rand.Int63()
suite.segmentID = rand.Int63() suite.segmentID = rand.Int63()
suite.segmentNum = 5 suite.segmentNum = 5
suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64, false) suite.schema = mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false)
} }
func (suite *SegmentLoaderDetailSuite) SetupTest() { func (suite *SegmentLoaderDetailSuite) SetupTest() {
@ -801,9 +802,9 @@ func (suite *SegmentLoaderDetailSuite) SetupTest() {
initcore.InitRemoteChunkManager(paramtable.Get()) initcore.InitRemoteChunkManager(paramtable.Get())
// Data // Data
schema := GenTestCollectionSchema("test", schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false)
indexMeta := GenTestIndexMeta(suite.collectionID, schema) indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema)
loadMeta := &querypb.LoadMetaInfo{ loadMeta := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID, CollectionID: suite.collectionID,

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
storage "github.com/milvus-io/milvus/internal/storage" storage "github.com/milvus-io/milvus/internal/storage"
@ -54,8 +55,8 @@ func (suite *SegmentSuite) SetupTest() {
suite.segmentID = 1 suite.segmentID = 1
suite.manager = NewManager() suite.manager = NewManager()
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
indexMeta := GenTestIndexMeta(suite.collectionID, schema) indexMeta := mock_segcore.GenTestIndexMeta(suite.collectionID, schema)
suite.manager.Collection.PutOrRef(suite.collectionID, suite.manager.Collection.PutOrRef(suite.collectionID,
schema, schema,
indexMeta, indexMeta,
@ -93,7 +94,7 @@ func (suite *SegmentSuite) SetupTest() {
) )
suite.Require().NoError(err) suite.Require().NoError(err)
binlogs, _, err := SaveBinLog(ctx, binlogs, _, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -124,7 +125,7 @@ func (suite *SegmentSuite) SetupTest() {
) )
suite.Require().NoError(err) suite.Require().NoError(err)
insertMsg, err := genInsertMsg(suite.collection, suite.partitionID, suite.growing.ID(), msgLength) insertMsg, err := mock_segcore.GenInsertMsg(suite.collection.GetCCollection(), suite.partitionID, suite.growing.ID(), msgLength)
suite.Require().NoError(err) suite.Require().NoError(err)
insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg) insertRecord, err := storage.TransferInsertMsgToInsertRecord(suite.collection.Schema(), insertMsg)
suite.Require().NoError(err) suite.Require().NoError(err)
@ -187,9 +188,9 @@ func (suite *SegmentSuite) TestDelete() {
} }
func (suite *SegmentSuite) TestHasRawData() { func (suite *SegmentSuite) TestHasRawData() {
has := suite.growing.HasRawData(simpleFloatVecField.id) has := suite.growing.HasRawData(mock_segcore.SimpleFloatVecField.ID)
suite.True(has) suite.True(has)
has = suite.sealed.HasRawData(simpleFloatVecField.id) has = suite.sealed.HasRawData(mock_segcore.SimpleFloatVecField.ID)
suite.True(has) suite.True(has)
} }

View File

@ -60,6 +60,7 @@ import (
"github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/internal/util/initcore"
"github.com/milvus-io/milvus/internal/util/searchutil/optimizers" "github.com/milvus-io/milvus/internal/util/searchutil/optimizers"
"github.com/milvus-io/milvus/internal/util/searchutil/scheduler" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/config" "github.com/milvus-io/milvus/pkg/config"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
@ -323,7 +324,7 @@ func (node *QueryNode) Init() error {
node.factory.Init(paramtable.Get()) node.factory.Init(paramtable.Get())
localRootPath := paramtable.Get().LocalStorageCfg.Path.GetValue() localRootPath := paramtable.Get().LocalStorageCfg.Path.GetValue()
localUsedSize, err := segments.GetLocalUsedSize(node.ctx, localRootPath) localUsedSize, err := segcore.GetLocalUsedSize(node.ctx, localRootPath)
if err != nil { if err != nil {
log.Warn("get local used size failed", zap.Error(err)) log.Warn("get local used size failed", zap.Error(err))
initError = err initError = err

View File

@ -32,12 +32,13 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/mocks/util/searchutil/mock_optimizers"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/searchutil/optimizers"
"github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
) )
@ -159,7 +160,7 @@ func (suite *QueryNodeSuite) TestInit_QueryHook() {
err = suite.node.Init() err = suite.node.Init()
suite.NoError(err) suite.NoError(err)
mockHook := optimizers.NewMockQueryHook(suite.T()) mockHook := mock_optimizers.NewMockQueryHook(suite.T())
suite.node.queryHook = mockHook suite.node.queryHook = mockHook
suite.node.handleQueryHookEvent() suite.node.handleQueryHookEvent()
@ -219,7 +220,7 @@ func (suite *QueryNodeSuite) TestStop() {
suite.node.manager = segments.NewManager() suite.node.manager = segments.NewManager()
schema := segments.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true)
collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
}) })

View File

@ -38,6 +38,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json" "github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
@ -262,8 +263,8 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
ctx := context.Background() ctx := context.Background()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
deltaLogs, err := segments.SaveDeltaLog(suite.collectionID, deltaLogs, err := mock_segcore.SaveDeltaLog(suite.collectionID,
suite.partitionIDs[0], suite.partitionIDs[0],
suite.flushedSegmentIDs[0], suite.flushedSegmentIDs[0],
suite.node.chunkManager, suite.node.chunkManager,
@ -306,7 +307,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsInt64() {
PartitionIDs: suite.partitionIDs, PartitionIDs: suite.partitionIDs,
MetricType: defaultMetricType, MetricType: defaultMetricType,
}, },
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema), IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
} }
// mocks // mocks
@ -331,7 +332,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
ctx := context.Background() ctx := context.Background()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false)
req := &querypb.WatchDmChannelsRequest{ req := &querypb.WatchDmChannelsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
@ -358,7 +359,7 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() {
PartitionIDs: suite.partitionIDs, PartitionIDs: suite.partitionIDs,
MetricType: defaultMetricType, MetricType: defaultMetricType,
}, },
IndexInfoList: segments.GenTestIndexInfoList(suite.collectionID, schema), IndexInfoList: mock_segcore.GenTestIndexInfoList(suite.collectionID, schema),
} }
// mocks // mocks
@ -383,9 +384,9 @@ func (suite *ServiceSuite) TestWatchDmChannels_Failed() {
ctx := context.Background() ctx := context.Background()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos) infos := suite.genSegmentLoadInfos(schema, indexInfos)
segmentInfos := lo.SliceToMap(infos, func(info *querypb.SegmentLoadInfo) (int64, *datapb.SegmentInfo) { segmentInfos := lo.SliceToMap(infos, func(info *querypb.SegmentLoadInfo) (int64, *datapb.SegmentInfo) {
@ -544,7 +545,7 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema
partNum := len(suite.partitionIDs) partNum := len(suite.partitionIDs)
infos := make([]*querypb.SegmentLoadInfo, 0) infos := make([]*querypb.SegmentLoadInfo, 0)
for i := 0; i < segNum; i++ { for i := 0; i < segNum; i++ {
binlogs, statslogs, err := segments.SaveBinLog(ctx, binlogs, statslogs, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionIDs[i%partNum], suite.partitionIDs[i%partNum],
suite.validSegmentIDs[i], suite.validSegmentIDs[i],
@ -559,7 +560,7 @@ func (suite *ServiceSuite) genSegmentLoadInfos(schema *schemapb.CollectionSchema
for offset, field := range vectorFieldSchemas { for offset, field := range vectorFieldSchemas {
indexInfo := lo.FindOrElse(indexInfos, nil, func(info *indexpb.IndexInfo) bool { return info.FieldID == field.GetFieldID() }) indexInfo := lo.FindOrElse(indexInfos, nil, func(info *indexpb.IndexInfo) bool { return info.FieldID == field.GetFieldID() })
if indexInfo != nil { if indexInfo != nil {
index, err := segments.GenAndSaveIndexV2( index, err := mock_segcore.GenAndSaveIndexV2(
suite.collectionID, suite.collectionID,
suite.partitionIDs[i%partNum], suite.partitionIDs[i%partNum],
suite.validSegmentIDs[i], suite.validSegmentIDs[i],
@ -595,8 +596,8 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() {
ctx := context.Background() ctx := context.Background()
suite.TestWatchDmChannelsInt64() suite.TestWatchDmChannelsInt64()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos) infos := suite.genSegmentLoadInfos(schema, indexInfos)
for _, info := range infos { for _, info := range infos {
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
@ -624,7 +625,7 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() {
ctx := context.Background() ctx := context.Background()
suite.TestWatchDmChannelsVarchar() suite.TestWatchDmChannelsVarchar()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false)
loadMeta := &querypb.LoadMetaInfo{ loadMeta := &querypb.LoadMetaInfo{
LoadType: querypb.LoadType_LoadCollection, LoadType: querypb.LoadType_LoadCollection,
CollectionID: suite.collectionID, CollectionID: suite.collectionID,
@ -661,7 +662,7 @@ func (suite *ServiceSuite) TestLoadDeltaInt64() {
ctx := context.Background() ctx := context.Background()
suite.TestLoadSegments_Int64() suite.TestLoadSegments_Int64()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgID: rand.Int63(), MsgID: rand.Int63(),
@ -686,7 +687,7 @@ func (suite *ServiceSuite) TestLoadDeltaVarchar() {
ctx := context.Background() ctx := context.Background()
suite.TestLoadSegments_VarChar() suite.TestLoadSegments_VarChar()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgID: rand.Int63(), MsgID: rand.Int63(),
@ -711,9 +712,9 @@ func (suite *ServiceSuite) TestLoadIndex_Success() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos) infos := suite.genSegmentLoadInfos(schema, indexInfos)
infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo { infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo {
info.SegmentID = info.SegmentID + 1000 info.SegmentID = info.SegmentID + 1000
@ -782,10 +783,10 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
suite.Run("load_non_exist_segment", func() { suite.Run("load_non_exist_segment", func() {
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos) infos := suite.genSegmentLoadInfos(schema, indexInfos)
infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo { infos = lo.Map(infos, func(info *querypb.SegmentLoadInfo, _ int) *querypb.SegmentLoadInfo {
info.SegmentID = info.SegmentID + 1000 info.SegmentID = info.SegmentID + 1000
@ -828,7 +829,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
mockLoader.EXPECT().LoadIndex(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mocked error")) mockLoader.EXPECT().LoadIndex(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mocked error"))
indexInfos := segments.GenTestIndexInfoList(suite.collectionID, schema) indexInfos := mock_segcore.GenTestIndexInfoList(suite.collectionID, schema)
infos := suite.genSegmentLoadInfos(schema, indexInfos) infos := suite.genSegmentLoadInfos(schema, indexInfos)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
@ -854,7 +855,7 @@ func (suite *ServiceSuite) TestLoadIndex_Failed() {
func (suite *ServiceSuite) TestLoadSegments_Failed() { func (suite *ServiceSuite) TestLoadSegments_Failed() {
ctx := context.Background() ctx := context.Background()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgID: rand.Int63(), MsgID: rand.Int63(),
@ -901,7 +902,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
delegator.EXPECT().TryCleanExcludedSegments(mock.Anything).Maybe() delegator.EXPECT().TryCleanExcludedSegments(mock.Anything).Maybe()
delegator.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).Return(nil) delegator.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).Return(nil)
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgID: rand.Int63(), MsgID: rand.Int63(),
@ -923,7 +924,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
suite.Run("delegator_not_found", func() { suite.Run("delegator_not_found", func() {
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgID: rand.Int63(), MsgID: rand.Int63(),
@ -953,7 +954,7 @@ func (suite *ServiceSuite) TestLoadSegments_Transfer() {
delegator.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")). delegator.EXPECT().LoadSegments(mock.Anything, mock.AnythingOfType("*querypb.LoadSegmentsRequest")).
Return(errors.New("mocked error")) Return(errors.New("mocked error"))
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
req := &querypb.LoadSegmentsRequest{ req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgID: rand.Int63(), MsgID: rand.Int63(),
@ -1245,7 +1246,7 @@ func (suite *ServiceSuite) TestSearch_Failed() {
ctx := context.Background() ctx := context.Background()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType", false) creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType", false)
req := &querypb.SearchRequest{ req := &querypb.SearchRequest{
Req: creq, Req: creq,
@ -1267,7 +1268,7 @@ func (suite *ServiceSuite) TestSearch_Failed() {
CollectionID: suite.collectionID, CollectionID: suite.collectionID,
PartitionIDs: suite.partitionIDs, PartitionIDs: suite.partitionIDs,
} }
indexMeta := suite.node.composeIndexMeta(segments.GenTestIndexInfoList(suite.collectionID, schema), schema) indexMeta := suite.node.composeIndexMeta(mock_segcore.GenTestIndexInfoList(suite.collectionID, schema), schema)
suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, LoadMeta) suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, indexMeta, LoadMeta)
// Delegator not found // Delegator not found
@ -1459,7 +1460,7 @@ func (suite *ServiceSuite) TestQuery_Normal() {
suite.TestLoadSegments_Int64() suite.TestLoadSegments_Int64()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema)
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -1478,7 +1479,7 @@ func (suite *ServiceSuite) TestQuery_Failed() {
defer cancel() defer cancel()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema)
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -1540,7 +1541,7 @@ func (suite *ServiceSuite) TestQueryStream_Normal() {
suite.TestLoadSegments_Int64() suite.TestLoadSegments_Int64()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema)
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -1575,7 +1576,7 @@ func (suite *ServiceSuite) TestQueryStream_Failed() {
defer cancel() defer cancel()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema)
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -1653,7 +1654,7 @@ func (suite *ServiceSuite) TestQuerySegments_Normal() {
suite.TestLoadSegments_Int64() suite.TestLoadSegments_Int64()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema)
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{
@ -1675,7 +1676,7 @@ func (suite *ServiceSuite) TestQueryStreamSegments_Normal() {
suite.TestLoadSegments_Int64() suite.TestLoadSegments_Int64()
// data // data
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false)
creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema) creq, err := suite.genCQueryRequest(10, IndexFaissIDMap, schema)
suite.NoError(err) suite.NoError(err)
req := &querypb.QueryRequest{ req := &querypb.QueryRequest{

View File

@ -7,6 +7,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/searchutil/scheduler" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/internal/util/streamrpc"
) )
@ -59,9 +60,8 @@ func (t *QueryStreamTask) PreExecute() error {
} }
func (t *QueryStreamTask) Execute() error { func (t *QueryStreamTask) Execute() error {
retrievePlan, err := segments.NewRetrievePlan( retrievePlan, err := segcore.NewRetrievePlan(
t.ctx, t.collection.GetCCollection(),
t.collection,
t.req.Req.GetSerializedExprPlan(), t.req.Req.GetSerializedExprPlan(),
t.req.Req.GetMvccTimestamp(), t.req.Req.GetMvccTimestamp(),
t.req.Req.Base.GetMsgID(), t.req.Req.Base.GetMsgID(),

View File

@ -16,6 +16,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/searchutil/scheduler" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
@ -100,9 +101,8 @@ func (t *QueryTask) Execute() error {
} }
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "QueryTask") tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "QueryTask")
retrievePlan, err := segments.NewRetrievePlan( retrievePlan, err := segcore.NewRetrievePlan(
t.ctx, t.collection.GetCCollection(),
t.collection,
t.req.Req.GetSerializedExprPlan(), t.req.Req.GetSerializedExprPlan(),
t.req.Req.GetMvccTimestamp(), t.req.Req.GetMvccTimestamp(),
t.req.Req.Base.GetMsgID(), t.req.Req.Base.GetMsgID(),

View File

@ -21,6 +21,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/searchutil/scheduler" "github.com/milvus-io/milvus/internal/util/searchutil/scheduler"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
@ -145,7 +146,7 @@ func (t *SearchTask) Execute() error {
if err != nil { if err != nil {
return err return err
} }
searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup) searchReq, err := segcore.NewSearchRequest(t.collection.GetCCollection(), req, t.placeholderGroup)
if err != nil { if err != nil {
return err return err
} }
@ -215,7 +216,7 @@ func (t *SearchTask) Execute() error {
}, 0) }, 0)
tr.RecordSpan() tr.RecordSpan()
blobs, err := segments.ReduceSearchResultsAndFillData( blobs, err := segcore.ReduceSearchResultsAndFillData(
t.ctx, t.ctx,
searchReq.Plan(), searchReq.Plan(),
results, results,
@ -227,7 +228,7 @@ func (t *SearchTask) Execute() error {
log.Warn("failed to reduce search results", zap.Error(err)) log.Warn("failed to reduce search results", zap.Error(err))
return err return err
} }
defer segments.DeleteSearchResultDataBlobs(blobs) defer segcore.DeleteSearchResultDataBlobs(blobs)
metrics.QueryNodeReduceLatency.WithLabelValues( metrics.QueryNodeReduceLatency.WithLabelValues(
fmt.Sprint(t.GetNodeID()), fmt.Sprint(t.GetNodeID()),
metrics.SearchLabel, metrics.SearchLabel,
@ -235,7 +236,7 @@ func (t *SearchTask) Execute() error {
metrics.BatchReduce). metrics.BatchReduce).
Observe(float64(tr.RecordSpan().Milliseconds())) Observe(float64(tr.RecordSpan().Milliseconds()))
for i := range t.originNqs { for i := range t.originNqs {
blob, err := segments.GetSearchResultDataBlob(t.ctx, blobs, i) blob, err := segcore.GetSearchResultDataBlob(t.ctx, blobs, i)
if err != nil { if err != nil {
return err return err
} }
@ -385,8 +386,8 @@ func (t *SearchTask) combinePlaceHolderGroups() error {
type StreamingSearchTask struct { type StreamingSearchTask struct {
SearchTask SearchTask
others []*StreamingSearchTask others []*StreamingSearchTask
resultBlobs segments.SearchResultDataBlobs resultBlobs segcore.SearchResultDataBlobs
streamReducer segments.StreamSearchReducer streamReducer segcore.StreamSearchReducer
} }
func NewStreamingSearchTask(ctx context.Context, func NewStreamingSearchTask(ctx context.Context,
@ -433,7 +434,7 @@ func (t *StreamingSearchTask) Execute() error {
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask") tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask")
req := t.req req := t.req
t.combinePlaceHolderGroups() t.combinePlaceHolderGroups()
searchReq, err := segments.NewSearchRequest(t.ctx, t.collection, req, t.placeholderGroup) searchReq, err := segcore.NewSearchRequest(t.collection.GetCCollection(), req, t.placeholderGroup)
if err != nil { if err != nil {
return err return err
} }
@ -455,14 +456,14 @@ func (t *StreamingSearchTask) Execute() error {
nil, nil,
req.GetSegmentIDs(), req.GetSegmentIDs(),
streamReduceFunc) streamReduceFunc)
defer segments.DeleteStreamReduceHelper(t.streamReducer) defer segcore.DeleteStreamReduceHelper(t.streamReducer)
defer t.segmentManager.Segment.Unpin(pinnedSegments) defer t.segmentManager.Segment.Unpin(pinnedSegments)
if err != nil { if err != nil {
log.Error("Failed to search sealed segments streamly", zap.Error(err)) log.Error("Failed to search sealed segments streamly", zap.Error(err))
return err return err
} }
t.resultBlobs, err = segments.GetStreamReduceResult(t.ctx, t.streamReducer) t.resultBlobs, err = segcore.GetStreamReduceResult(t.ctx, t.streamReducer)
defer segments.DeleteSearchResultDataBlobs(t.resultBlobs) defer segcore.DeleteSearchResultDataBlobs(t.resultBlobs)
if err != nil { if err != nil {
log.Error("Failed to get stream-reduced search result") log.Error("Failed to get stream-reduced search result")
return err return err
@ -488,7 +489,7 @@ func (t *StreamingSearchTask) Execute() error {
return nil return nil
} }
tr.RecordSpan() tr.RecordSpan()
t.resultBlobs, err = segments.ReduceSearchResultsAndFillData( t.resultBlobs, err = segcore.ReduceSearchResultsAndFillData(
t.ctx, t.ctx,
searchReq.Plan(), searchReq.Plan(),
results, results,
@ -500,7 +501,7 @@ func (t *StreamingSearchTask) Execute() error {
log.Warn("failed to reduce search results", zap.Error(err)) log.Warn("failed to reduce search results", zap.Error(err))
return err return err
} }
defer segments.DeleteSearchResultDataBlobs(t.resultBlobs) defer segcore.DeleteSearchResultDataBlobs(t.resultBlobs)
metrics.QueryNodeReduceLatency.WithLabelValues( metrics.QueryNodeReduceLatency.WithLabelValues(
fmt.Sprint(t.GetNodeID()), fmt.Sprint(t.GetNodeID()),
metrics.SearchLabel, metrics.SearchLabel,
@ -514,7 +515,7 @@ func (t *StreamingSearchTask) Execute() error {
// 2. reorganize blobs to original search request // 2. reorganize blobs to original search request
for i := range t.originNqs { for i := range t.originNqs {
blob, err := segments.GetSearchResultDataBlob(t.ctx, t.resultBlobs, i) blob, err := segcore.GetSearchResultDataBlob(t.ctx, t.resultBlobs, i)
if err != nil { if err != nil {
return err return err
} }
@ -584,19 +585,19 @@ func (t *StreamingSearchTask) maybeReturnForEmptyResults(results []*segments.Sea
} }
func (t *StreamingSearchTask) streamReduce(ctx context.Context, func (t *StreamingSearchTask) streamReduce(ctx context.Context,
plan *segments.SearchPlan, plan *segcore.SearchPlan,
newResult *segments.SearchResult, newResult *segments.SearchResult,
sliceNQs []int64, sliceNQs []int64,
sliceTopKs []int64, sliceTopKs []int64,
) error { ) error {
if t.streamReducer == nil { if t.streamReducer == nil {
var err error var err error
t.streamReducer, err = segments.NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs) t.streamReducer, err = segcore.NewStreamReducer(ctx, plan, sliceNQs, sliceTopKs)
if err != nil { if err != nil {
log.Error("Fail to init stream reducer, return") log.Error("Fail to init stream reducer, return")
return err return err
} }
} }
return segments.StreamReduceSearchResult(ctx, newResult, t.streamReducer) return segcore.StreamReduceSearchResult(ctx, newResult, t.streamReducer)
} }

View File

@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/internal/mocks/util/searchutil/mock_optimizers"
"github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
@ -36,7 +37,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
suite.Run("normal_run", func() { suite.Run("normal_run", func() {
paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true")
mockHook := NewMockQueryHook(suite.T()) mockHook := mock_optimizers.NewMockQueryHook(suite.T())
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
params[common.TopKKey] = int64(50) params[common.TopKKey] = int64(50)
params[common.SearchParamKey] = `{"param": 2}` params[common.SearchParamKey] = `{"param": 2}`
@ -87,7 +88,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
}) })
suite.Run("disable optimization", func() { suite.Run("disable optimization", func() {
mockHook := NewMockQueryHook(suite.T()) mockHook := mock_optimizers.NewMockQueryHook(suite.T())
suite.queryHook = mockHook suite.queryHook = mockHook
defer func() { suite.queryHook = nil }() defer func() { suite.queryHook = nil }()
@ -144,7 +145,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
suite.Run("other_plannode", func() { suite.Run("other_plannode", func() {
paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true")
mockHook := NewMockQueryHook(suite.T()) mockHook := mock_optimizers.NewMockQueryHook(suite.T())
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
params[common.TopKKey] = int64(50) params[common.TopKKey] = int64(50)
params[common.SearchParamKey] = `{"param": 2}` params[common.SearchParamKey] = `{"param": 2}`
@ -174,7 +175,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
suite.Run("no_serialized_plan", func() { suite.Run("no_serialized_plan", func() {
paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true")
defer paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key) defer paramtable.Get().Reset(paramtable.Get().AutoIndexConfig.Enable.Key)
mockHook := NewMockQueryHook(suite.T()) mockHook := mock_optimizers.NewMockQueryHook(suite.T())
suite.queryHook = mockHook suite.queryHook = mockHook
defer func() { suite.queryHook = nil }() defer func() { suite.queryHook = nil }()
@ -187,7 +188,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() {
suite.Run("hook_run_error", func() { suite.Run("hook_run_error", func() {
paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true") paramtable.Get().Save(paramtable.Get().AutoIndexConfig.Enable.Key, "true")
mockHook := NewMockQueryHook(suite.T()) mockHook := mock_optimizers.NewMockQueryHook(suite.T())
mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) {
params[common.TopKKey] = int64(50) params[common.TopKKey] = int64(50)
params[common.SearchParamKey] = `{"param": 2}` params[common.SearchParamKey] = `{"param": 2}`

View File

@ -0,0 +1,78 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segcore
/*
#cgo pkg-config: milvus_core
#include "segcore/collection_c.h"
#include "common/type_c.h"
#include "segcore/segment_c.h"
#include "storage/storage_c.h"
*/
import "C"
import (
"context"
"math"
"unsafe"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus/internal/util/cgoconverter"
"github.com/milvus-io/milvus/pkg/util/merr"
)
type CStatus = C.CStatus
// ConsumeCStatusIntoError consumes the CStatus and returns the error
func ConsumeCStatusIntoError(status *C.CStatus) error {
if status == nil || status.error_code == 0 {
return nil
}
errorCode := status.error_code
errorMsg := C.GoString(status.error_msg)
C.free(unsafe.Pointer(status.error_msg))
return merr.SegcoreError(int32(errorCode), errorMsg)
}
// unmarshalCProto unmarshal the proto from C memory
func unmarshalCProto(cRes *C.CProto, msg proto.Message) error {
blob := (*(*[math.MaxInt32]byte)(cRes.proto_blob))[:int(cRes.proto_size):int(cRes.proto_size)]
return proto.Unmarshal(blob, msg)
}
// getCProtoBlob returns the raw C memory, invoker should release it itself
func getCProtoBlob(cProto *C.CProto) []byte {
lease, blob := cgoconverter.UnsafeGoBytes(&cProto.proto_blob, int(cProto.proto_size))
cgoconverter.Extract(lease)
return blob
}
// GetLocalUsedSize returns the used size of the local path
func GetLocalUsedSize(ctx context.Context, path string) (int64, error) {
var availableSize int64
cSize := (*C.int64_t)(&availableSize)
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
status := C.GetLocalUsedSize(cPath, cSize)
if err := ConsumeCStatusIntoError(&status); err != nil {
return 0, err
}
return availableSize, nil
}

View File

@ -0,0 +1,19 @@
package segcore
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestConsumeCStatusIntoError(t *testing.T) {
err := ConsumeCStatusIntoError(nil)
assert.NoError(t, err)
}
func TestGetLocalUsedSize(t *testing.T) {
size, err := GetLocalUsedSize(context.Background(), "")
assert.NoError(t, err)
assert.NotNil(t, size)
}

View File

@ -0,0 +1,84 @@
package segcore
/*
#cgo pkg-config: milvus_core
#include "segcore/collection_c.h"
#include "segcore/segment_c.h"
*/
import "C"
import (
"unsafe"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
)
// CreateCCollectionRequest is a request to create a CCollection.
type CreateCCollectionRequest struct {
CollectionID int64
Schema *schemapb.CollectionSchema
IndexMeta *segcorepb.CollectionIndexMeta
}
// CreateCCollection creates a CCollection from a CreateCCollectionRequest.
func CreateCCollection(req *CreateCCollectionRequest) (*CCollection, error) {
schemaBlob, err := proto.Marshal(req.Schema)
if err != nil {
return nil, errors.New("marshal schema failed")
}
var indexMetaBlob []byte
if req.IndexMeta != nil {
indexMetaBlob, err = proto.Marshal(req.IndexMeta)
if err != nil {
return nil, errors.New("marshal index meta failed")
}
}
ptr := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob)))
if indexMetaBlob != nil {
C.SetIndexMeta(ptr, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob)))
}
return &CCollection{
collectionID: req.CollectionID,
ptr: ptr,
schema: req.Schema,
indexMeta: req.IndexMeta,
}, nil
}
// CCollection is just a wrapper of the underlying C-structure CCollection.
// Contains some additional immutable properties of collection.
type CCollection struct {
ptr C.CCollection
collectionID int64
schema *schemapb.CollectionSchema
indexMeta *segcorepb.CollectionIndexMeta
}
// ID returns the collection ID.
func (c *CCollection) ID() int64 {
return c.collectionID
}
// rawPointer returns the underlying C-structure pointer.
func (c *CCollection) rawPointer() C.CCollection {
return c.ptr
}
func (c *CCollection) Schema() *schemapb.CollectionSchema {
return c.schema
}
func (c *CCollection) IndexMeta() *segcorepb.CollectionIndexMeta {
return c.indexMeta
}
// Release releases the underlying collection
func (c *CCollection) Release() {
C.DeleteCollection(c.ptr)
c.ptr = nil
}

View File

@ -0,0 +1,29 @@
package segcore_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestCollection(t *testing.T) {
paramtable.Init()
schema := mock_segcore.GenTestCollectionSchema("test", schemapb.DataType_Int64, false)
indexMeta := mock_segcore.GenTestIndexMeta(1, schema)
ccollection, err := segcore.CreateCCollection(&segcore.CreateCCollectionRequest{
CollectionID: 1,
Schema: schema,
IndexMeta: indexMeta,
})
assert.NoError(t, err)
assert.NotNil(t, ccollection)
assert.NotNil(t, ccollection.Schema())
assert.NotNil(t, ccollection.IndexMeta())
assert.Equal(t, int64(1), ccollection.ID())
defer ccollection.Release()
}

View File

@ -14,11 +14,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package segments package segcore
/* /*
#cgo pkg-config: milvus_core #cgo pkg-config: milvus_core
#include "common/type_c.h"
#include "segcore/collection_c.h" #include "segcore/collection_c.h"
#include "segcore/segment_c.h" #include "segcore/segment_c.h"
#include "segcore/plan_c.h" #include "segcore/plan_c.h"
@ -26,15 +27,14 @@ package segments
import "C" import "C"
import ( import (
"context"
"fmt"
"unsafe" "unsafe"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
. "github.com/milvus-io/milvus/pkg/util/typeutil" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
) )
// SearchPlan is a wrapper of the underlying C-structure C.CSearchPlan // SearchPlan is a wrapper of the underlying C-structure C.CSearchPlan
@ -42,22 +42,16 @@ type SearchPlan struct {
cSearchPlan C.CSearchPlan cSearchPlan C.CSearchPlan
} }
func createSearchPlanByExpr(ctx context.Context, col *Collection, expr []byte) (*SearchPlan, error) { func createSearchPlanByExpr(col *CCollection, expr []byte) (*SearchPlan, error) {
if col.collectionPtr == nil {
return nil, errors.New("nil collection ptr, collectionID = " + fmt.Sprintln(col.id))
}
var cPlan C.CSearchPlan var cPlan C.CSearchPlan
status := C.CreateSearchPlanByExpr(col.collectionPtr, unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan) status := C.CreateSearchPlanByExpr(col.rawPointer(), unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan)
if err := ConsumeCStatusIntoError(&status); err != nil {
err1 := HandleCStatus(ctx, &status, "Create Plan by expr failed") return nil, errors.Wrap(err, "Create Plan by expr failed")
if err1 != nil {
return nil, err1
} }
return &SearchPlan{cSearchPlan: cPlan}, nil return &SearchPlan{cSearchPlan: cPlan}, nil
} }
func (plan *SearchPlan) getTopK() int64 { func (plan *SearchPlan) GetTopK() int64 {
topK := C.GetTopK(plan.cSearchPlan) topK := C.GetTopK(plan.cSearchPlan)
return int64(topK) return int64(topK)
} }
@ -82,15 +76,15 @@ func (plan *SearchPlan) delete() {
type SearchRequest struct { type SearchRequest struct {
plan *SearchPlan plan *SearchPlan
cPlaceholderGroup C.CPlaceholderGroup cPlaceholderGroup C.CPlaceholderGroup
msgID UniqueID msgID int64
searchFieldID UniqueID searchFieldID int64
mvccTimestamp Timestamp mvccTimestamp typeutil.Timestamp
} }
func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.SearchRequest, placeholderGrp []byte) (*SearchRequest, error) { func NewSearchRequest(collection *CCollection, req *querypb.SearchRequest, placeholderGrp []byte) (*SearchRequest, error) {
metricType := req.GetReq().GetMetricType() metricType := req.GetReq().GetMetricType()
expr := req.Req.SerializedExprPlan expr := req.Req.SerializedExprPlan
plan, err := createSearchPlanByExpr(ctx, collection, expr) plan, err := createSearchPlanByExpr(collection, expr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -104,10 +98,9 @@ func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.
blobSize := C.int64_t(len(placeholderGrp)) blobSize := C.int64_t(len(placeholderGrp))
var cPlaceholderGroup C.CPlaceholderGroup var cPlaceholderGroup C.CPlaceholderGroup
status := C.ParsePlaceholderGroup(plan.cSearchPlan, blobPtr, blobSize, &cPlaceholderGroup) status := C.ParsePlaceholderGroup(plan.cSearchPlan, blobPtr, blobSize, &cPlaceholderGroup)
if err := ConsumeCStatusIntoError(&status); err != nil {
if err := HandleCStatus(ctx, &status, "parser searchRequest failed"); err != nil {
plan.delete() plan.delete()
return nil, err return nil, errors.Wrap(err, "parser searchRequest failed")
} }
metricTypeInPlan := plan.GetMetricType() metricTypeInPlan := plan.GetMetricType()
@ -118,23 +111,21 @@ func NewSearchRequest(ctx context.Context, collection *Collection, req *querypb.
var fieldID C.int64_t var fieldID C.int64_t
status = C.GetFieldID(plan.cSearchPlan, &fieldID) status = C.GetFieldID(plan.cSearchPlan, &fieldID)
if err = HandleCStatus(ctx, &status, "get fieldID from plan failed"); err != nil { if err := ConsumeCStatusIntoError(&status); err != nil {
plan.delete() plan.delete()
return nil, err return nil, errors.Wrap(err, "get fieldID from plan failed")
} }
ret := &SearchRequest{ return &SearchRequest{
plan: plan, plan: plan,
cPlaceholderGroup: cPlaceholderGroup, cPlaceholderGroup: cPlaceholderGroup,
msgID: req.GetReq().GetBase().GetMsgID(), msgID: req.GetReq().GetBase().GetMsgID(),
searchFieldID: int64(fieldID), searchFieldID: int64(fieldID),
mvccTimestamp: req.GetReq().GetMvccTimestamp(), mvccTimestamp: req.GetReq().GetMvccTimestamp(),
} }, nil
return ret, nil
} }
func (req *SearchRequest) getNumOfQuery() int64 { func (req *SearchRequest) GetNumOfQuery() int64 {
numQueries := C.GetNumOfQueries(req.cPlaceholderGroup) numQueries := C.GetNumOfQueries(req.cPlaceholderGroup)
return int64(numQueries) return int64(numQueries)
} }
@ -143,6 +134,10 @@ func (req *SearchRequest) Plan() *SearchPlan {
return req.plan return req.plan
} }
func (req *SearchRequest) SearchFieldID() int64 {
return req.searchFieldID
}
func (req *SearchRequest) Delete() { func (req *SearchRequest) Delete() {
if req.plan != nil { if req.plan != nil {
req.plan.delete() req.plan.delete()
@ -150,59 +145,49 @@ func (req *SearchRequest) Delete() {
C.DeletePlaceholderGroup(req.cPlaceholderGroup) C.DeletePlaceholderGroup(req.cPlaceholderGroup)
} }
func parseSearchRequest(ctx context.Context, plan *SearchPlan, searchRequestBlob []byte) (*SearchRequest, error) {
if len(searchRequestBlob) == 0 {
return nil, fmt.Errorf("empty search request")
}
blobPtr := unsafe.Pointer(&searchRequestBlob[0])
blobSize := C.int64_t(len(searchRequestBlob))
var cPlaceholderGroup C.CPlaceholderGroup
status := C.ParsePlaceholderGroup(plan.cSearchPlan, blobPtr, blobSize, &cPlaceholderGroup)
if err := HandleCStatus(ctx, &status, "parser searchRequest failed"); err != nil {
return nil, err
}
ret := &SearchRequest{cPlaceholderGroup: cPlaceholderGroup, plan: plan}
return ret, nil
}
// RetrievePlan is a wrapper of the underlying C-structure C.CRetrievePlan // RetrievePlan is a wrapper of the underlying C-structure C.CRetrievePlan
type RetrievePlan struct { type RetrievePlan struct {
cRetrievePlan C.CRetrievePlan cRetrievePlan C.CRetrievePlan
Timestamp Timestamp Timestamp typeutil.Timestamp
msgID UniqueID // only used to debug. msgID int64 // only used to debug.
maxLimitSize int64
ignoreNonPk bool ignoreNonPk bool
} }
func NewRetrievePlan(ctx context.Context, col *Collection, expr []byte, timestamp Timestamp, msgID UniqueID) (*RetrievePlan, error) { func NewRetrievePlan(col *CCollection, expr []byte, timestamp typeutil.Timestamp, msgID int64) (*RetrievePlan, error) {
col.mu.RLock() if col.rawPointer() == nil {
defer col.mu.RUnlock() return nil, errors.New("collection is released")
if col.collectionPtr == nil {
return nil, merr.WrapErrCollectionNotFound(col.id, "collection released")
} }
var cPlan C.CRetrievePlan var cPlan C.CRetrievePlan
status := C.CreateRetrievePlanByExpr(col.collectionPtr, unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan) status := C.CreateRetrievePlanByExpr(col.rawPointer(), unsafe.Pointer(&expr[0]), (C.int64_t)(len(expr)), &cPlan)
if err := ConsumeCStatusIntoError(&status); err != nil {
err := HandleCStatus(ctx, &status, "Create retrieve plan by expr failed") return nil, errors.Wrap(err, "Create retrieve plan by expr failed")
if err != nil {
return nil, err
} }
maxLimitSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
newPlan := &RetrievePlan{ return &RetrievePlan{
cRetrievePlan: cPlan, cRetrievePlan: cPlan,
Timestamp: timestamp, Timestamp: timestamp,
msgID: msgID, msgID: msgID,
} maxLimitSize: maxLimitSize,
return newPlan, nil }, nil
} }
func (plan *RetrievePlan) ShouldIgnoreNonPk() bool { func (plan *RetrievePlan) ShouldIgnoreNonPk() bool {
return bool(C.ShouldIgnoreNonPk(plan.cRetrievePlan)) return bool(C.ShouldIgnoreNonPk(plan.cRetrievePlan))
} }
func (plan *RetrievePlan) SetIgnoreNonPk(ignore bool) {
plan.ignoreNonPk = ignore
}
func (plan *RetrievePlan) IsIgnoreNonPk() bool {
return plan.ignoreNonPk
}
func (plan *RetrievePlan) MsgID() int64 {
return plan.msgID
}
func (plan *RetrievePlan) Delete() { func (plan *RetrievePlan) Delete() {
C.DeleteRetrievePlan(plan.cRetrievePlan) C.DeleteRetrievePlan(plan.cRetrievePlan)
} }

View File

@ -14,18 +14,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package segments package segcore_test
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
) )
@ -36,47 +38,46 @@ type PlanSuite struct {
collectionID int64 collectionID int64
partitionID int64 partitionID int64
segmentID int64 segmentID int64
collection *Collection collection *segcore.CCollection
} }
func (suite *PlanSuite) SetupTest() { func (suite *PlanSuite) SetupTest() {
suite.collectionID = 100 suite.collectionID = 100
suite.partitionID = 10 suite.partitionID = 10
suite.segmentID = 1 suite.segmentID = 1
schema := GenTestCollectionSchema("plan-suite", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("plan-suite", schemapb.DataType_Int64, true)
suite.collection = NewCollection(suite.collectionID, schema, GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ var err error
LoadType: querypb.LoadType_LoadCollection, suite.collection, err = segcore.CreateCCollection(&segcore.CreateCCollectionRequest{
Schema: schema,
IndexMeta: mock_segcore.GenTestIndexMeta(suite.collectionID, schema),
}) })
suite.collection.AddPartition(suite.partitionID) if err != nil {
panic(err)
}
} }
func (suite *PlanSuite) TearDownTest() { func (suite *PlanSuite) TearDownTest() {
DeleteCollection(suite.collection) suite.collection.Release()
} }
func (suite *PlanSuite) TestPlanCreateByExpr() { func (suite *PlanSuite) TestPlanCreateByExpr() {
planNode := &planpb.PlanNode{ planNode := &planpb.PlanNode{
OutputFieldIds: []int64{rowIDFieldID}, OutputFieldIds: []int64{0},
} }
expr, err := proto.Marshal(planNode) expr, err := proto.Marshal(planNode)
suite.NoError(err) suite.NoError(err)
_, err = createSearchPlanByExpr(context.Background(), suite.collection, expr) _, err = segcore.NewSearchRequest(suite.collection, &querypb.SearchRequest{
suite.Error(err) Req: &internalpb.SearchRequest{
} SerializedExprPlan: expr,
},
func (suite *PlanSuite) TestPlanFail() { }, nil)
collection := &Collection{
id: -1,
}
_, err := createSearchPlanByExpr(context.Background(), collection, nil)
suite.Error(err) suite.Error(err)
} }
func (suite *PlanSuite) TestQueryPlanCollectionReleased() { func (suite *PlanSuite) TestQueryPlanCollectionReleased() {
collection := &Collection{id: suite.collectionID} suite.collection.Release()
_, err := NewRetrievePlan(context.Background(), collection, nil, 0, 0) _, err := segcore.NewRetrievePlan(suite.collection, nil, 0, 0)
suite.Error(err) suite.Error(err)
} }

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package segments package segcore
/* /*
#cgo pkg-config: milvus_core #cgo pkg-config: milvus_core
@ -27,6 +27,8 @@ import "C"
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/cockroachdb/errors"
) )
type SliceInfo struct { type SliceInfo struct {
@ -34,22 +36,12 @@ type SliceInfo struct {
SliceTopKs []int64 SliceTopKs []int64
} }
// SearchResult contains a pointer to the search result in C++ memory
type SearchResult struct {
cSearchResult C.CSearchResult
}
// SearchResultDataBlobs is the CSearchResultsDataBlobs in C++ // SearchResultDataBlobs is the CSearchResultsDataBlobs in C++
type ( type (
SearchResultDataBlobs = C.CSearchResultDataBlobs SearchResultDataBlobs = C.CSearchResultDataBlobs
StreamSearchReducer = C.CSearchStreamReducer StreamSearchReducer = C.CSearchStreamReducer
) )
// RetrieveResult contains a pointer to the retrieve result in C++ memory
type RetrieveResult struct {
cRetrieveResult C.CRetrieveResult
}
func ParseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *SliceInfo { func ParseSliceInfo(originNQs []int64, originTopKs []int64, nqPerSlice int64) *SliceInfo {
sInfo := &SliceInfo{ sInfo := &SliceInfo{
SliceNQs: make([]int64, 0), SliceNQs: make([]int64, 0),
@ -94,8 +86,8 @@ func NewStreamReducer(ctx context.Context,
var streamReducer StreamSearchReducer var streamReducer StreamSearchReducer
status := C.NewStreamReducer(plan.cSearchPlan, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices, &streamReducer) status := C.NewStreamReducer(plan.cSearchPlan, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices, &streamReducer)
if err := HandleCStatus(ctx, &status, "MergeSearchResultsWithOutputFields failed"); err != nil { if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, err return nil, errors.Wrap(err, "MergeSearchResultsWithOutputFields failed")
} }
return streamReducer, nil return streamReducer, nil
} }
@ -108,8 +100,8 @@ func StreamReduceSearchResult(ctx context.Context,
cSearchResultPtr := &cSearchResults[0] cSearchResultPtr := &cSearchResults[0]
status := C.StreamReduce(streamReducer, cSearchResultPtr, 1) status := C.StreamReduce(streamReducer, cSearchResultPtr, 1)
if err := HandleCStatus(ctx, &status, "StreamReduceSearchResult failed"); err != nil { if err := ConsumeCStatusIntoError(&status); err != nil {
return err return errors.Wrap(err, "StreamReduceSearchResult failed")
} }
return nil return nil
} }
@ -117,8 +109,8 @@ func StreamReduceSearchResult(ctx context.Context,
func GetStreamReduceResult(ctx context.Context, streamReducer StreamSearchReducer) (SearchResultDataBlobs, error) { func GetStreamReduceResult(ctx context.Context, streamReducer StreamSearchReducer) (SearchResultDataBlobs, error) {
var cSearchResultDataBlobs SearchResultDataBlobs var cSearchResultDataBlobs SearchResultDataBlobs
status := C.GetStreamReduceResult(streamReducer, &cSearchResultDataBlobs) status := C.GetStreamReduceResult(streamReducer, &cSearchResultDataBlobs)
if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil { if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, err return nil, errors.Wrap(err, "ReduceSearchResultsAndFillData failed")
} }
return cSearchResultDataBlobs, nil return cSearchResultDataBlobs, nil
} }
@ -154,8 +146,8 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
traceCtx := ParseCTraceContext(ctx) traceCtx := ParseCTraceContext(ctx)
status := C.ReduceSearchResultsAndFillData(traceCtx.ctx, &cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr, status := C.ReduceSearchResultsAndFillData(traceCtx.ctx, &cSearchResultDataBlobs, plan.cSearchPlan, cSearchResultPtr,
cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices) cNumSegments, cSliceNQSPtr, cSliceTopKSPtr, cNumSlices)
if err := HandleCStatus(ctx, &status, "ReduceSearchResultsAndFillData failed"); err != nil { if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, err return nil, errors.Wrap(err, "ReduceSearchResultsAndFillData failed")
} }
return cSearchResultDataBlobs, nil return cSearchResultDataBlobs, nil
} }
@ -163,10 +155,10 @@ func ReduceSearchResultsAndFillData(ctx context.Context, plan *SearchPlan, searc
func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs SearchResultDataBlobs, blobIndex int) ([]byte, error) { func GetSearchResultDataBlob(ctx context.Context, cSearchResultDataBlobs SearchResultDataBlobs, blobIndex int) ([]byte, error) {
var blob C.CProto var blob C.CProto
status := C.GetSearchResultDataBlob(&blob, cSearchResultDataBlobs, C.int32_t(blobIndex)) status := C.GetSearchResultDataBlob(&blob, cSearchResultDataBlobs, C.int32_t(blobIndex))
if err := HandleCStatus(ctx, &status, "marshal failed"); err != nil { if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, err return nil, errors.Wrap(err, "marshal failed")
} }
return GetCProtoBlob(&blob), nil return getCProtoBlob(&blob), nil
} }
func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) { func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) {
@ -176,14 +168,3 @@ func DeleteSearchResultDataBlobs(cSearchResultDataBlobs SearchResultDataBlobs) {
func DeleteStreamReduceHelper(cStreamReduceHelper StreamSearchReducer) { func DeleteStreamReduceHelper(cStreamReduceHelper StreamSearchReducer) {
C.DeleteStreamSearchReducer(cStreamReduceHelper) C.DeleteStreamSearchReducer(cStreamReduceHelper)
} }
func DeleteSearchResults(results []*SearchResult) {
if len(results) == 0 {
return
}
for _, result := range results {
if result != nil {
C.DeleteSearchResult(result.cSearchResult)
}
}
}

View File

@ -14,13 +14,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
package segments package segcore_test
import ( import (
"context" "context"
"fmt" "fmt"
"log" "log"
"math" "math"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -29,11 +30,13 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
storage "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/initcore" "github.com/milvus-io/milvus/internal/util/initcore"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/funcutil"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
@ -49,8 +52,8 @@ type ReduceSuite struct {
collectionID int64 collectionID int64
partitionID int64 partitionID int64
segmentID int64 segmentID int64
collection *Collection collection *segcore.CCollection
segment Segment segment segcore.CSegment
} }
func (suite *ReduceSuite) SetupSuite() { func (suite *ReduceSuite) SetupSuite() {
@ -58,7 +61,10 @@ func (suite *ReduceSuite) SetupSuite() {
} }
func (suite *ReduceSuite) SetupTest() { func (suite *ReduceSuite) SetupTest() {
var err error localDataRootPath := filepath.Join(paramtable.Get().LocalStorageCfg.Path.GetValue(), typeutil.QueryNodeRole)
initcore.InitLocalChunkManager(localDataRootPath)
err := initcore.InitMmapManager(paramtable.Get())
suite.NoError(err)
ctx := context.Background() ctx := context.Background()
msgLength := 100 msgLength := 100
@ -70,29 +76,22 @@ func (suite *ReduceSuite) SetupTest() {
suite.collectionID = 100 suite.collectionID = 100
suite.partitionID = 10 suite.partitionID = 10
suite.segmentID = 1 suite.segmentID = 1
schema := GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true) schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
suite.collection = NewCollection(suite.collectionID, suite.collection, err = segcore.CreateCCollection(&segcore.CreateCCollectionRequest{
schema, CollectionID: suite.collectionID,
GenTestIndexMeta(suite.collectionID, schema), Schema: schema,
&querypb.LoadMetaInfo{ IndexMeta: mock_segcore.GenTestIndexMeta(suite.collectionID, schema),
LoadType: querypb.LoadType_LoadCollection, })
}) suite.NoError(err)
suite.segment, err = NewSegment(ctx, suite.segment, err = segcore.CreateCSegment(&segcore.CreateCSegmentRequest{
suite.collection, Collection: suite.collection,
SegmentTypeSealed, SegmentID: suite.segmentID,
0, SegmentType: segcore.SegmentTypeSealed,
&querypb.SegmentLoadInfo{ IsSorted: false,
SegmentID: suite.segmentID, })
CollectionID: suite.collectionID,
PartitionID: suite.partitionID,
NumOfRows: int64(msgLength),
InsertChannel: fmt.Sprintf("by-dev-rootcoord-dml_0_%dv0", suite.collectionID),
Level: datapb.SegmentLevel_Legacy,
},
)
suite.Require().NoError(err) suite.Require().NoError(err)
binlogs, _, err := SaveBinLog(ctx, binlogs, _, err := mock_segcore.SaveBinLog(ctx,
suite.collectionID, suite.collectionID,
suite.partitionID, suite.partitionID,
suite.segmentID, suite.segmentID,
@ -101,15 +100,19 @@ func (suite *ReduceSuite) SetupTest() {
suite.chunkManager, suite.chunkManager,
) )
suite.Require().NoError(err) suite.Require().NoError(err)
for _, binlog := range binlogs { req := &segcore.LoadFieldDataRequest{
err = suite.segment.(*LocalSegment).LoadFieldData(ctx, binlog.FieldID, int64(msgLength), binlog) RowCount: int64(msgLength),
suite.Require().NoError(err)
} }
for _, binlog := range binlogs {
req.Fields = append(req.Fields, segcore.LoadFieldDataInfo{Field: binlog})
}
_, err = suite.segment.LoadFieldData(ctx, req)
suite.Require().NoError(err)
} }
func (suite *ReduceSuite) TearDownTest() { func (suite *ReduceSuite) TearDownTest() {
suite.segment.Release(context.Background()) suite.segment.Release()
DeleteCollection(suite.collection) suite.collection.Release()
ctx := context.Background() ctx := context.Background()
suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath) suite.chunkManager.RemoveWithPrefix(ctx, suite.rootPath)
} }
@ -118,7 +121,7 @@ func (suite *ReduceSuite) TestReduceParseSliceInfo() {
originNQs := []int64{2, 3, 2} originNQs := []int64{2, 3, 2}
originTopKs := []int64{10, 5, 20} originTopKs := []int64{10, 5, 20}
nqPerSlice := int64(2) nqPerSlice := int64(2)
sInfo := ParseSliceInfo(originNQs, originTopKs, nqPerSlice) sInfo := segcore.ParseSliceInfo(originNQs, originTopKs, nqPerSlice)
expectedSliceNQs := []int64{2, 2, 1, 2} expectedSliceNQs := []int64{2, 2, 1, 2}
expectedSliceTopKs := []int64{10, 5, 5, 20} expectedSliceTopKs := []int64{10, 5, 5, 20}
@ -130,7 +133,7 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
nq := int64(10) nq := int64(10)
// TODO: replace below by genPlaceholderGroup(nq) // TODO: replace below by genPlaceholderGroup(nq)
vec := testutils.GenerateFloatVectors(1, defaultDim) vec := testutils.GenerateFloatVectors(1, mock_segcore.DefaultDim)
var searchRawData []byte var searchRawData []byte
for i, ele := range vec { for i, ele := range vec {
buf := make([]byte, 4) buf := make([]byte, 4)
@ -167,35 +170,73 @@ func (suite *ReduceSuite) TestReduceAllFunc() {
> >
placeholder_tag: "$0" placeholder_tag: "$0"
>` >`
var planpb planpb.PlanNode var planNode planpb.PlanNode
// proto.UnmarshalText(planStr, &planpb) // proto.UnmarshalText(planStr, &planpb)
prototext.Unmarshal([]byte(planStr), &planpb) prototext.Unmarshal([]byte(planStr), &planNode)
serializedPlan, err := proto.Marshal(&planpb) serializedPlan, err := proto.Marshal(&planNode)
suite.NoError(err) suite.NoError(err)
plan, err := createSearchPlanByExpr(context.Background(), suite.collection, serializedPlan) searchReq, err := segcore.NewSearchRequest(suite.collection, &querypb.SearchRequest{
suite.NoError(err) Req: &internalpb.SearchRequest{
searchReq, err := parseSearchRequest(context.Background(), plan, placeGroupByte) SerializedExprPlan: serializedPlan,
searchReq.mvccTimestamp = typeutil.MaxTimestamp MvccTimestamp: typeutil.MaxTimestamp,
},
}, placeGroupByte)
suite.NoError(err) suite.NoError(err)
defer searchReq.Delete() defer searchReq.Delete()
searchResult, err := suite.segment.Search(context.Background(), searchReq) searchResult, err := suite.segment.Search(context.Background(), searchReq)
suite.NoError(err) suite.NoError(err)
err = checkSearchResult(context.Background(), nq, plan, searchResult) err = mock_segcore.CheckSearchResult(context.Background(), nq, searchReq.Plan(), searchResult)
suite.NoError(err) suite.NoError(err)
// Test Illegal Query
retrievePlan, err := segcore.NewRetrievePlan(
suite.collection,
[]byte(fmt.Sprintf("%d > 100", mock_segcore.RowIDField.ID)),
typeutil.MaxTimestamp,
0)
suite.Error(err)
suite.Nil(retrievePlan)
plan := planpb.PlanNode{
Node: &planpb.PlanNode_Query{
Query: &planpb.QueryPlanNode{
IsCount: true,
},
},
}
expr, err := proto.Marshal(&plan)
suite.NoError(err)
retrievePlan, err = segcore.NewRetrievePlan(
suite.collection,
expr,
typeutil.MaxTimestamp,
0)
suite.NotNil(retrievePlan)
suite.NoError(err)
retrieveResult, err := suite.segment.Retrieve(context.Background(), retrievePlan)
suite.NotNil(retrieveResult)
suite.NoError(err)
result, err := retrieveResult.GetResult()
suite.NoError(err)
suite.NotNil(result)
suite.Equal(int64(100), result.AllRetrieveCount)
retrieveResult.Release()
} }
func (suite *ReduceSuite) TestReduceInvalid() { func (suite *ReduceSuite) TestReduceInvalid() {
plan := &SearchPlan{} plan := &segcore.SearchPlan{}
_, err := ReduceSearchResultsAndFillData(context.Background(), plan, nil, 1, nil, nil) _, err := segcore.ReduceSearchResultsAndFillData(context.Background(), plan, nil, 1, nil, nil)
suite.Error(err) suite.Error(err)
searchReq, err := genSearchPlanAndRequests(suite.collection, []int64{suite.segmentID}, IndexHNSW, 10) searchReq, err := mock_segcore.GenSearchPlanAndRequests(suite.collection, []int64{suite.segmentID}, mock_segcore.IndexHNSW, 10)
suite.NoError(err) suite.NoError(err)
searchResults := make([]*SearchResult, 0) searchResults := make([]*segcore.SearchResult, 0)
searchResults = append(searchResults, nil) searchResults = append(searchResults, nil)
_, err = ReduceSearchResultsAndFillData(context.Background(), searchReq.plan, searchResults, 1, []int64{10}, []int64{10}) _, err = segcore.ReduceSearchResultsAndFillData(context.Background(), searchReq.Plan(), searchResults, 1, []int64{10}, []int64{10})
suite.Error(err) suite.Error(err)
} }

View File

@ -0,0 +1,100 @@
package segcore
/*
#cgo pkg-config: milvus_core
#include "segcore/load_field_data_c.h"
*/
import "C"
import (
"unsafe"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type RetrievePlanWithOffsets struct {
*RetrievePlan
Offsets []int64
}
type InsertRequest struct {
RowIDs []int64
Timestamps []typeutil.Timestamp
Record *segcorepb.InsertRecord
}
type DeleteRequest struct {
PrimaryKeys storage.PrimaryKeys
Timestamps []typeutil.Timestamp
}
type LoadFieldDataRequest struct {
Fields []LoadFieldDataInfo
MMapDir string
RowCount int64
}
type LoadFieldDataInfo struct {
Field *datapb.FieldBinlog
EnableMMap bool
}
func (req *LoadFieldDataRequest) getCLoadFieldDataRequest() (result *cLoadFieldDataRequest, err error) {
var cLoadFieldDataInfo C.CLoadFieldDataInfo
status := C.NewLoadFieldDataInfo(&cLoadFieldDataInfo)
if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, errors.Wrap(err, "NewLoadFieldDataInfo failed")
}
defer func() {
if err != nil {
C.DeleteLoadFieldDataInfo(cLoadFieldDataInfo)
}
}()
rowCount := C.int64_t(req.RowCount)
for _, field := range req.Fields {
cFieldID := C.int64_t(field.Field.GetFieldID())
status = C.AppendLoadFieldInfo(cLoadFieldDataInfo, cFieldID, rowCount)
if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, errors.Wrapf(err, "AppendLoadFieldInfo failed at fieldID, %d", field.Field.GetFieldID())
}
for _, binlog := range field.Field.Binlogs {
cEntriesNum := C.int64_t(binlog.GetEntriesNum())
cFile := C.CString(binlog.GetLogPath())
defer C.free(unsafe.Pointer(cFile))
status = C.AppendLoadFieldDataPath(cLoadFieldDataInfo, cFieldID, cEntriesNum, cFile)
if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, errors.Wrapf(err, "AppendLoadFieldDataPath failed at binlog, %d, %s", field.Field.GetFieldID(), binlog.GetLogPath())
}
}
C.EnableMmap(cLoadFieldDataInfo, cFieldID, C.bool(field.EnableMMap))
}
if len(req.MMapDir) > 0 {
mmapDir := C.CString(req.MMapDir)
defer C.free(unsafe.Pointer(mmapDir))
C.AppendMMapDirPath(cLoadFieldDataInfo, mmapDir)
}
return &cLoadFieldDataRequest{
cLoadFieldDataInfo: cLoadFieldDataInfo,
}, nil
}
type cLoadFieldDataRequest struct {
cLoadFieldDataInfo C.CLoadFieldDataInfo
}
func (req *cLoadFieldDataRequest) Release() {
C.DeleteLoadFieldDataInfo(req.cLoadFieldDataInfo)
}
type AddFieldDataInfoRequest = LoadFieldDataRequest
type AddFieldDataInfoResult struct{}

View File

@ -0,0 +1,34 @@
package segcore
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/proto/datapb"
)
func TestLoadFieldDataRequest(t *testing.T) {
req := &LoadFieldDataRequest{
Fields: []LoadFieldDataInfo{{
Field: &datapb.FieldBinlog{
FieldID: 1,
Binlogs: []*datapb.Binlog{
{
EntriesNum: 100,
LogPath: "1",
}, {
EntriesNum: 101,
LogPath: "2",
},
},
},
}},
RowCount: 100,
MMapDir: "1234567890",
}
creq, err := req.getCLoadFieldDataRequest()
assert.NoError(t, err)
assert.NotNil(t, creq)
creq.Release()
}

View File

@ -0,0 +1,47 @@
package segcore
/*
#cgo pkg-config: milvus_core
#include "segcore/plan_c.h"
#include "segcore/reduce_c.h"
*/
import "C"
import (
"github.com/milvus-io/milvus/internal/proto/segcorepb"
)
type SearchResult struct {
cSearchResult C.CSearchResult
}
func (r *SearchResult) Release() {
C.DeleteSearchResult(r.cSearchResult)
r.cSearchResult = nil
}
type RetrieveResult struct {
cRetrieveResult *C.CRetrieveResult
}
func (r *RetrieveResult) GetResult() (*segcorepb.RetrieveResults, error) {
retrieveResult := new(segcorepb.RetrieveResults)
if err := unmarshalCProto(r.cRetrieveResult, retrieveResult); err != nil {
return nil, err
}
return retrieveResult, nil
}
func (r *RetrieveResult) Release() {
C.DeleteRetrieveResult(r.cRetrieveResult)
r.cRetrieveResult = nil
}
type InsertResult struct {
InsertedRows int64
}
type DeleteResult struct{}
type LoadFieldDataResult struct{}

View File

@ -0,0 +1,24 @@
package segcore
/*
#cgo pkg-config: milvus_core
#include "segcore/segcore_init_c.h"
*/
import "C"
// IndexEngineInfo contains all the information about the index engine.
type IndexEngineInfo struct {
MinIndexVersion int32
CurrentIndexVersion int32
}
// GetIndexEngineInfo returns the minimal and current version of the index engine.
func GetIndexEngineInfo() IndexEngineInfo {
cMinimal, cCurrent := C.GetMinimalIndexVersion(), C.GetCurrentIndexVersion()
return IndexEngineInfo{
MinIndexVersion: int32(cMinimal),
CurrentIndexVersion: int32(cCurrent),
}
}

View File

@ -0,0 +1,13 @@
package segcore
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetIndexEngineInfo(t *testing.T) {
r := GetIndexEngineInfo()
assert.NotZero(t, r.CurrentIndexVersion)
assert.Zero(t, r.MinIndexVersion)
}

View File

@ -0,0 +1,293 @@
package segcore
/*
#cgo pkg-config: milvus_core
#include "common/type_c.h"
#include "futures/future_c.h"
#include "segcore/collection_c.h"
#include "segcore/plan_c.h"
#include "segcore/reduce_c.h"
*/
import "C"
import (
"context"
"fmt"
"runtime"
"unsafe"
"github.com/cockroachdb/errors"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/cgo"
"github.com/milvus-io/milvus/pkg/util/merr"
)
const (
SegmentTypeGrowing SegmentType = commonpb.SegmentState_Growing
SegmentTypeSealed SegmentType = commonpb.SegmentState_Sealed
)
type (
SegmentType = commonpb.SegmentState
CSegmentInterface C.CSegmentInterface
)
// CreateCSegmentRequest is a request to create a segment.
type CreateCSegmentRequest struct {
Collection *CCollection
SegmentID int64
SegmentType SegmentType
IsSorted bool
EnableChunked bool
}
func (req *CreateCSegmentRequest) getCSegmentType() C.SegmentType {
var segmentType C.SegmentType
switch req.SegmentType {
case SegmentTypeGrowing:
segmentType = C.Growing
case SegmentTypeSealed:
if req.EnableChunked {
segmentType = C.ChunkedSealed
break
}
segmentType = C.Sealed
default:
panic(fmt.Sprintf("invalid segment type: %d", req.SegmentType))
}
return segmentType
}
// CreateCSegment creates a segment from a CreateCSegmentRequest.
func CreateCSegment(req *CreateCSegmentRequest) (CSegment, error) {
var ptr C.CSegmentInterface
status := C.NewSegment(req.Collection.rawPointer(), req.getCSegmentType(), C.int64_t(req.SegmentID), &ptr, C.bool(req.IsSorted))
if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, err
}
return &cSegmentImpl{id: req.SegmentID, ptr: ptr}, nil
}
// cSegmentImpl is a wrapper for cSegmentImplInterface.
type cSegmentImpl struct {
id int64
ptr C.CSegmentInterface
}
// ID returns the ID of the segment.
func (s *cSegmentImpl) ID() int64 {
return s.id
}
// RawPointer returns the raw pointer of the segment.
func (s *cSegmentImpl) RawPointer() CSegmentInterface {
return CSegmentInterface(s.ptr)
}
// RowNum returns the number of rows in the segment.
func (s *cSegmentImpl) RowNum() int64 {
rowCount := C.GetRealCount(s.ptr)
return int64(rowCount)
}
// MemSize returns the memory size of the segment.
func (s *cSegmentImpl) MemSize() int64 {
cMemSize := C.GetMemoryUsageInBytes(s.ptr)
return int64(cMemSize)
}
// HasRawData checks if the segment has raw data.
func (s *cSegmentImpl) HasRawData(fieldID int64) bool {
ret := C.HasRawData(s.ptr, C.int64_t(fieldID))
return bool(ret)
}
// Search requests a search on the segment.
func (s *cSegmentImpl) Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error) {
traceCtx := ParseCTraceContext(ctx)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(searchReq)
future := cgo.Async(ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncSearch(
traceCtx.ctx,
s.ptr,
searchReq.plan.cSearchPlan,
searchReq.cPlaceholderGroup,
C.uint64_t(searchReq.mvccTimestamp),
))
},
cgo.WithName("search"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil {
return nil, err
}
return &SearchResult{cSearchResult: (C.CSearchResult)(result)}, nil
}
// Retrieve retrieves entities from the segment.
func (s *cSegmentImpl) Retrieve(ctx context.Context, plan *RetrievePlan) (*RetrieveResult, error) {
traceCtx := ParseCTraceContext(ctx)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(plan)
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncRetrieve(
traceCtx.ctx,
s.ptr,
plan.cRetrievePlan,
C.uint64_t(plan.Timestamp),
C.int64_t(plan.maxLimitSize),
C.bool(plan.ignoreNonPk),
))
},
cgo.WithName("retrieve"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil {
return nil, err
}
return &RetrieveResult{cRetrieveResult: (*C.CRetrieveResult)(result)}, nil
}
// RetrieveByOffsets retrieves entities from the segment by offsets.
func (s *cSegmentImpl) RetrieveByOffsets(ctx context.Context, plan *RetrievePlanWithOffsets) (*RetrieveResult, error) {
if len(plan.Offsets) == 0 {
return nil, merr.WrapErrParameterInvalid("segment offsets", "empty offsets")
}
traceCtx := ParseCTraceContext(ctx)
defer runtime.KeepAlive(traceCtx)
defer runtime.KeepAlive(plan)
defer runtime.KeepAlive(plan.Offsets)
future := cgo.Async(
ctx,
func() cgo.CFuturePtr {
return (cgo.CFuturePtr)(C.AsyncRetrieveByOffsets(
traceCtx.ctx,
s.ptr,
plan.cRetrievePlan,
(*C.int64_t)(unsafe.Pointer(&plan.Offsets[0])),
C.int64_t(len(plan.Offsets)),
))
},
cgo.WithName("retrieve-by-offsets"),
)
defer future.Release()
result, err := future.BlockAndLeakyGet()
if err != nil {
return nil, err
}
return &RetrieveResult{cRetrieveResult: (*C.CRetrieveResult)(result)}, nil
}
// Insert inserts entities into the segment.
func (s *cSegmentImpl) Insert(ctx context.Context, request *InsertRequest) (*InsertResult, error) {
offset, err := s.preInsert(len(request.RowIDs))
if err != nil {
return nil, err
}
insertRecordBlob, err := proto.Marshal(request.Record)
if err != nil {
return nil, fmt.Errorf("failed to marshal insert record: %s", err)
}
numOfRow := len(request.RowIDs)
cOffset := C.int64_t(offset)
cNumOfRows := C.int64_t(numOfRow)
cEntityIDsPtr := (*C.int64_t)(&(request.RowIDs)[0])
cTimestampsPtr := (*C.uint64_t)(&(request.Timestamps)[0])
status := C.Insert(s.ptr,
cOffset,
cNumOfRows,
cEntityIDsPtr,
cTimestampsPtr,
(*C.uint8_t)(unsafe.Pointer(&insertRecordBlob[0])),
(C.uint64_t)(len(insertRecordBlob)),
)
return &InsertResult{InsertedRows: int64(numOfRow)}, ConsumeCStatusIntoError(&status)
}
func (s *cSegmentImpl) preInsert(numOfRecords int) (int64, error) {
var offset int64
cOffset := (*C.int64_t)(&offset)
status := C.PreInsert(s.ptr, C.int64_t(int64(numOfRecords)), cOffset)
if err := ConsumeCStatusIntoError(&status); err != nil {
return 0, err
}
return offset, nil
}
// Delete deletes entities from the segment.
func (s *cSegmentImpl) Delete(ctx context.Context, request *DeleteRequest) (*DeleteResult, error) {
cOffset := C.int64_t(0) // depre
cSize := C.int64_t(request.PrimaryKeys.Len())
cTimestampsPtr := (*C.uint64_t)(&(request.Timestamps)[0])
ids, err := storage.ParsePrimaryKeysBatch2IDs(request.PrimaryKeys)
if err != nil {
return nil, err
}
dataBlob, err := proto.Marshal(ids)
if err != nil {
return nil, fmt.Errorf("failed to marshal ids: %s", err)
}
status := C.Delete(s.ptr,
cOffset,
cSize,
(*C.uint8_t)(unsafe.Pointer(&dataBlob[0])),
(C.uint64_t)(len(dataBlob)),
cTimestampsPtr,
)
return &DeleteResult{}, ConsumeCStatusIntoError(&status)
}
// LoadFieldData loads field data into the segment.
func (s *cSegmentImpl) LoadFieldData(ctx context.Context, request *LoadFieldDataRequest) (*LoadFieldDataResult, error) {
creq, err := request.getCLoadFieldDataRequest()
if err != nil {
return nil, err
}
defer creq.Release()
status := C.LoadFieldData(s.ptr, creq.cLoadFieldDataInfo)
if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, errors.Wrap(err, "failed to load field data")
}
return &LoadFieldDataResult{}, nil
}
// AddFieldDataInfo adds field data info into the segment.
func (s *cSegmentImpl) AddFieldDataInfo(ctx context.Context, request *AddFieldDataInfoRequest) (*AddFieldDataInfoResult, error) {
creq, err := request.getCLoadFieldDataRequest()
if err != nil {
return nil, err
}
defer creq.Release()
status := C.AddFieldDataInfoForSealed(s.ptr, creq.cLoadFieldDataInfo)
if err := ConsumeCStatusIntoError(&status); err != nil {
return nil, errors.Wrap(err, "failed to add field data info")
}
return &AddFieldDataInfoResult{}, nil
}
// Release releases the segment.
func (s *cSegmentImpl) Release() {
C.DeleteSegment(s.ptr)
}

View File

@ -0,0 +1,74 @@
package segcore
/*
#cgo pkg-config: milvus_core
#include "common/type_c.h"
*/
import "C"
import "context"
// CSegment is the interface of a segcore segment.
// TODO: We should separate the interface of CGrowingSegment and CSealedSegment,
// Because they have different implementations, GrowingSegment will only be used at streamingnode, SealedSegment will only be used at querynode.
// But currently, we just use the same interface to represent them to keep compatible with querynode LocalSegment.
type CSegment interface {
GrowingSegment
SealedSegment
}
// GrowingSegment is the interface of a growing segment.
type GrowingSegment interface {
basicSegmentMethodSet
// Insert inserts data into the segment.
Insert(ctx context.Context, request *InsertRequest) (*InsertResult, error)
}
// SealedSegment is the interface of a sealed segment.
type SealedSegment interface {
basicSegmentMethodSet
// LoadFieldData loads field data into the segment.
LoadFieldData(ctx context.Context, request *LoadFieldDataRequest) (*LoadFieldDataResult, error)
// AddFieldDataInfo adds field data info into the segment.
AddFieldDataInfo(ctx context.Context, request *AddFieldDataInfoRequest) (*AddFieldDataInfoResult, error)
}
// basicSegmentMethodSet is the basic method set of a segment.
type basicSegmentMethodSet interface {
// ID returns the ID of the segment.
ID() int64
// RawPointer returns the raw pointer of the segment.
// TODO: should be removed in future.
RawPointer() CSegmentInterface
// RawPointer returns the raw pointer of the segment.
RowNum() int64
// MemSize returns the memory size of the segment.
MemSize() int64
// HasRawData checks if the segment has raw data.
HasRawData(fieldID int64) bool
// Search requests a search on the segment.
Search(ctx context.Context, searchReq *SearchRequest) (*SearchResult, error)
// Retrieve retrieves entities from the segment.
Retrieve(ctx context.Context, plan *RetrievePlan) (*RetrieveResult, error)
// RetrieveByOffsets retrieves entities from the segment by offsets.
RetrieveByOffsets(ctx context.Context, plan *RetrievePlanWithOffsets) (*RetrieveResult, error)
// Delete deletes data from the segment.
Delete(ctx context.Context, request *DeleteRequest) (*DeleteResult, error)
// Release releases the segment.
Release()
}

View File

@ -0,0 +1,137 @@
package segcore_test
import (
"context"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks/util/mock_segcore"
"github.com/milvus-io/milvus/internal/proto/planpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/initcore"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
func TestGrowingSegment(t *testing.T) {
paramtable.Init()
localDataRootPath := filepath.Join(paramtable.Get().LocalStorageCfg.Path.GetValue(), typeutil.QueryNodeRole)
initcore.InitLocalChunkManager(localDataRootPath)
err := initcore.InitMmapManager(paramtable.Get())
assert.NoError(t, err)
collectionID := int64(100)
segmentID := int64(100)
schema := mock_segcore.GenTestCollectionSchema("test-reduce", schemapb.DataType_Int64, true)
collection, err := segcore.CreateCCollection(&segcore.CreateCCollectionRequest{
CollectionID: collectionID,
Schema: schema,
IndexMeta: mock_segcore.GenTestIndexMeta(collectionID, schema),
})
assert.NoError(t, err)
assert.NotNil(t, collection)
defer collection.Release()
segment, err := segcore.CreateCSegment(&segcore.CreateCSegmentRequest{
Collection: collection,
SegmentID: segmentID,
SegmentType: segcore.SegmentTypeGrowing,
IsSorted: false,
})
assert.NoError(t, err)
assert.NotNil(t, segment)
defer segment.Release()
assert.Equal(t, segmentID, segment.ID())
assert.Equal(t, int64(0), segment.RowNum())
assert.Zero(t, segment.MemSize())
assert.True(t, segment.HasRawData(0))
assertEqualCount(t, collection, segment, 0)
insertMsg, err := mock_segcore.GenInsertMsg(collection, 1, segmentID, 100)
assert.NoError(t, err)
insertResult, err := segment.Insert(context.Background(), &segcore.InsertRequest{
RowIDs: insertMsg.RowIDs,
Timestamps: insertMsg.Timestamps,
Record: &segcorepb.InsertRecord{
FieldsData: insertMsg.FieldsData,
NumRows: int64(len(insertMsg.RowIDs)),
},
})
assert.NoError(t, err)
assert.NotNil(t, insertResult)
assert.Equal(t, int64(100), insertResult.InsertedRows)
assert.Equal(t, int64(100), segment.RowNum())
assertEqualCount(t, collection, segment, 100)
pk := storage.NewInt64PrimaryKeys(1)
pk.Append(storage.NewInt64PrimaryKey(10))
deleteResult, err := segment.Delete(context.Background(), &segcore.DeleteRequest{
PrimaryKeys: pk,
Timestamps: []typeutil.Timestamp{
1000,
},
})
assert.NoError(t, err)
assert.NotNil(t, deleteResult)
assert.Equal(t, int64(99), segment.RowNum())
}
func assertEqualCount(
t *testing.T,
collection *segcore.CCollection,
segment segcore.CSegment,
count int64,
) {
plan := planpb.PlanNode{
Node: &planpb.PlanNode_Query{
Query: &planpb.QueryPlanNode{
IsCount: true,
},
},
}
expr, err := proto.Marshal(&plan)
assert.NoError(t, err)
retrievePlan, err := segcore.NewRetrievePlan(
collection,
expr,
typeutil.MaxTimestamp,
100)
defer retrievePlan.Delete()
assert.True(t, retrievePlan.ShouldIgnoreNonPk())
assert.False(t, retrievePlan.IsIgnoreNonPk())
retrievePlan.SetIgnoreNonPk(true)
assert.True(t, retrievePlan.IsIgnoreNonPk())
assert.NotZero(t, retrievePlan.MsgID())
assert.NotNil(t, retrievePlan)
assert.NoError(t, err)
retrieveResult, err := segment.Retrieve(context.Background(), retrievePlan)
assert.NotNil(t, retrieveResult)
assert.NoError(t, err)
result, err := retrieveResult.GetResult()
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, count, result.AllRetrieveCount)
retrieveResult.Release()
retrieveResult2, err := segment.RetrieveByOffsets(context.Background(), &segcore.RetrievePlanWithOffsets{
RetrievePlan: retrievePlan,
Offsets: []int64{0, 1, 2, 3, 4},
})
assert.NoError(t, err)
assert.NotNil(t, retrieveResult2)
retrieveResult2.Release()
}

View File

@ -0,0 +1,56 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package segcore
/*
#cgo pkg-config: milvus_core
#include "segcore/segment_c.h"
*/
import "C"
import (
"context"
"unsafe"
"go.opentelemetry.io/otel/trace"
)
// CTraceContext is the wrapper for `C.CTraceContext`
// it stores the internal C.CTraceContext and
type CTraceContext struct {
traceID trace.TraceID
spanID trace.SpanID
ctx C.CTraceContext
}
// ParseCTraceContext parses tracing span and convert it into `C.CTraceContext`.
func ParseCTraceContext(ctx context.Context) *CTraceContext {
span := trace.SpanFromContext(ctx)
cctx := &CTraceContext{
traceID: span.SpanContext().TraceID(),
spanID: span.SpanContext().SpanID(),
}
cctx.ctx = C.CTraceContext{
traceID: (*C.uint8_t)(unsafe.Pointer(&cctx.traceID[0])),
spanID: (*C.uint8_t)(unsafe.Pointer(&cctx.spanID[0])),
traceFlags: (C.uint8_t)(span.SpanContext().TraceFlags()),
}
return cctx
}