Add unittest for loadIndex service

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2020-12-26 14:16:51 +08:00 committed by yefu.chen
parent 2031c54746
commit d599407e2b
9 changed files with 194 additions and 47 deletions

View File

@ -32,6 +32,12 @@ NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info) {
} }
} }
void
DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info) {
auto info = (LoadIndexInfo*)c_load_index_info;
delete info;
}
CStatus CStatus
AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* c_index_key, const char* c_index_value) { AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* c_index_key, const char* c_index_value) {
try { try {

View File

@ -25,6 +25,9 @@ typedef void* CBinarySet;
CStatus CStatus
NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info); NewLoadIndexInfo(CLoadIndexInfo* c_load_index_info);
void
DeleteLoadIndexInfo(CLoadIndexInfo c_load_index_info);
CStatus CStatus
AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* index_key, const char* index_value); AppendIndexParam(CLoadIndexInfo c_load_index_info, const char* index_key, const char* index_value);

View File

@ -176,8 +176,9 @@ FillTargetEntry(CSegmentBase c_segment, CPlan c_plan, CQueryResult c_result) {
CStatus CStatus
UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) { UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) {
auto load_index_info = (LoadIndexInfo*)c_load_index_info;
try { try {
auto segment = (milvus::segcore::SegmentBase*)c_segment;
auto load_index_info = (LoadIndexInfo*)c_load_index_info;
auto status = CStatus(); auto status = CStatus();
status.error_code = Success; status.error_code = Success;
status.error_msg = ""; status.error_msg = "";
@ -189,7 +190,6 @@ UpdateSegmentIndex(CSegmentBase c_segment, CLoadIndexInfo c_load_index_info) {
return status; return status;
} }
} }
////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////
int int

View File

@ -685,6 +685,49 @@ TEST(CApiTest, Reduce) {
DeleteSegment(segment); DeleteSegment(segment);
} }
TEST(CApiTest, LoadIndexInfo) {
// generator index
constexpr auto DIM = 16;
constexpr auto K = 10;
auto N = 1024 * 10;
auto [raw_data, timestamps, uids] = generate_data(N);
auto indexing = std::make_shared<milvus::knowhere::IVFPQ>();
auto conf = milvus::knowhere::Config{{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::nprobe, 4},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::L2},
{milvus::knowhere::meta::DEVICEID, 0}};
auto database = milvus::knowhere::GenDataset(N, DIM, raw_data.data());
indexing->Train(database, conf);
indexing->AddWithoutIds(database, conf);
EXPECT_EQ(indexing->Count(), N);
EXPECT_EQ(indexing->Dim(), DIM);
auto binary_set = indexing->Serialize(conf);
CBinarySet c_binary_set = (CBinarySet)&binary_set;
void* c_load_index_info = nullptr;
auto status = NewLoadIndexInfo(&c_load_index_info);
assert(status.error_code == Success);
std::string index_param_key1 = "index_type";
std::string index_param_value1 = "IVF_PQ";
status = AppendIndexParam(c_load_index_info, index_param_key1.data(), index_param_value1.data());
std::string index_param_key2 = "index_mode";
std::string index_param_value2 = "cpu";
status = AppendIndexParam(c_load_index_info, index_param_key2.data(), index_param_value2.data());
assert(status.error_code == Success);
std::string field_name = "field0";
status = AppendFieldInfo(c_load_index_info, field_name.data(), 0);
assert(status.error_code == Success);
status = AppendIndex(c_load_index_info, c_binary_set);
assert(status.error_code == Success);
DeleteLoadIndexInfo(c_load_index_info);
}
TEST(CApiTest, LoadIndex_Search) { TEST(CApiTest, LoadIndex_Search) {
// generator index // generator index
constexpr auto DIM = 16; constexpr auto DIM = 16;

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
) )
@ -21,18 +22,28 @@ func NewLoadIndexClient(ctx context.Context, pulsarAddress string, loadIndexChan
} }
} }
func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, indexParam map[string]string) error { func (lic *LoadIndexClient) LoadIndex(indexPaths []string, segmentID int64, fieldID int64, fieldName string, indexParams map[string]string) error {
// TODO:: add indexParam to proto
baseMsg := msgstream.BaseMsg{ baseMsg := msgstream.BaseMsg{
BeginTimestamp: 0, BeginTimestamp: 0,
EndTimestamp: 0, EndTimestamp: 0,
HashValues: []uint32{0}, HashValues: []uint32{0},
} }
var indexParamsKV []*commonpb.KeyValuePair
for indexParam := range indexParams {
indexParamsKV = append(indexParamsKV, &commonpb.KeyValuePair{
Key: indexParam,
Value: indexParams[indexParam],
})
}
loadIndexRequest := internalPb.LoadIndex{ loadIndexRequest := internalPb.LoadIndex{
MsgType: internalPb.MsgType_kLoadIndex, MsgType: internalPb.MsgType_kLoadIndex,
SegmentID: segmentID, SegmentID: segmentID,
FieldName: fieldName,
FieldID: fieldID, FieldID: fieldID,
IndexPaths: indexPaths, IndexPaths: indexPaths,
IndexParams: indexParamsKV,
} }
loadIndexMsg := &msgstream.LoadIndexMsg{ loadIndexMsg := &msgstream.LoadIndexMsg{

View File

@ -18,7 +18,7 @@ type LoadIndexInfo struct {
cLoadIndexInfo C.CLoadIndexInfo cLoadIndexInfo C.CLoadIndexInfo
} }
func NewLoadIndexInfo() (*LoadIndexInfo, error) { func newLoadIndexInfo() (*LoadIndexInfo, error) {
var cLoadIndexInfo C.CLoadIndexInfo var cLoadIndexInfo C.CLoadIndexInfo
status := C.NewLoadIndexInfo(&cLoadIndexInfo) status := C.NewLoadIndexInfo(&cLoadIndexInfo)
errorCode := status.error_code errorCode := status.error_code
@ -31,7 +31,11 @@ func NewLoadIndexInfo() (*LoadIndexInfo, error) {
return &LoadIndexInfo{cLoadIndexInfo: cLoadIndexInfo}, nil return &LoadIndexInfo{cLoadIndexInfo: cLoadIndexInfo}, nil
} }
func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) error { func deleteLoadIndexInfo(info *LoadIndexInfo) {
C.DeleteLoadIndexInfo(info.cLoadIndexInfo)
}
func (li *LoadIndexInfo) appendIndexParam(indexKey string, indexValue string) error {
cIndexKey := C.CString(indexKey) cIndexKey := C.CString(indexKey)
cIndexValue := C.CString(indexValue) cIndexValue := C.CString(indexValue)
status := C.AppendIndexParam(li.cLoadIndexInfo, cIndexKey, cIndexValue) status := C.AppendIndexParam(li.cLoadIndexInfo, cIndexKey, cIndexValue)
@ -45,7 +49,7 @@ func (li *LoadIndexInfo) AppendIndexParam(indexKey string, indexValue string) er
return nil return nil
} }
func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error { func (li *LoadIndexInfo) appendFieldInfo(fieldName string, fieldID int64) error {
cFieldName := C.CString(fieldName) cFieldName := C.CString(fieldName)
cFieldID := C.long(fieldID) cFieldID := C.long(fieldID)
status := C.AppendFieldInfo(li.cLoadIndexInfo, cFieldName, cFieldID) status := C.AppendFieldInfo(li.cLoadIndexInfo, cFieldName, cFieldID)
@ -59,7 +63,7 @@ func (li *LoadIndexInfo) AppendFieldInfo(fieldName string, fieldID int64) error
return nil return nil
} }
func (li *LoadIndexInfo) AppendIndex(bytesIndex [][]byte, indexKeys []string) error { func (li *LoadIndexInfo) appendIndex(bytesIndex [][]byte, indexKeys []string) error {
var cBinarySet C.CBinarySet var cBinarySet C.CBinarySet
status := C.NewBinarySet(&cBinarySet) status := C.NewBinarySet(&cBinarySet)

View File

@ -0,0 +1,36 @@
package querynode
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
)
func TestLoadIndexInfo(t *testing.T) {
indexParams := make([]*commonpb.KeyValuePair, 0)
indexParams = append(indexParams, &commonpb.KeyValuePair{
Key: "index_type",
Value: "IVF_PQ",
})
indexParams = append(indexParams, &commonpb.KeyValuePair{
Key: "index_mode",
Value: "cpu",
})
indexBytes := make([][]byte, 0)
indexValue := make([]byte, 10)
indexBytes = append(indexBytes, indexValue)
indexPaths := make([]string, 0)
indexPaths = append(indexPaths, "index-0")
loadIndexInfo, err := newLoadIndexInfo()
assert.Nil(t, err)
for _, indexParam := range indexParams {
loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value)
}
loadIndexInfo.appendFieldInfo("field0", 0)
loadIndexInfo.appendIndex(indexBytes, indexPaths)
deleteLoadIndexInfo(loadIndexInfo)
}

View File

@ -107,16 +107,27 @@ func (lis *loadIndexService) start() {
log.Println("type assertion failed for LoadIndexMsg") log.Println("type assertion failed for LoadIndexMsg")
continue continue
} }
/* TODO: debug //// 1. use msg's index paths to get index bytes
// 1. use msg's index paths to get index bytes //var indexBuffer [][]byte
indexBuffer := lis.loadIndex(indexMsg.IndexPaths) //var err error
// 2. use index bytes and index path to update segment //fn := func() error {
err := lis.updateSegmentIndex(indexBuffer, indexMsg.IndexPaths, indexMsg.SegmentID) // indexBuffer, err = lis.loadIndex(indexMsg.IndexPaths)
if err != nil { // if err != nil {
log.Println(err) // return err
continue // }
} // return nil
*/ //}
//err = msgstream.Retry(5, time.Millisecond*200, fn)
//if err != nil {
// log.Println(err)
// continue
//}
//// 2. use index bytes and index path to update segment
//err = lis.updateSegmentIndex(indexBuffer, indexMsg)
//if err != nil {
// log.Println(err)
// continue
//}
//3. update segment index stats //3. update segment index stats
err := lis.updateSegmentIndexStats(indexMsg) err := lis.updateSegmentIndexStats(indexMsg)
if err != nil { if err != nil {
@ -216,7 +227,7 @@ func (lis *loadIndexService) updateSegmentIndexStats(indexMsg *msgstream.LoadInd
return nil return nil
} }
func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte { func (lis *loadIndexService) loadIndex(indexPath []string) ([][]byte, error) {
index := make([][]byte, 0) index := make([][]byte, 0)
for _, path := range indexPath { for _, path := range indexPath {
@ -224,13 +235,12 @@ func (lis *loadIndexService) loadIndex(indexPath []string) [][]byte {
binarySetKey := filepath.Base(path) binarySetKey := filepath.Base(path)
indexPiece, err := (*lis.client).Load(binarySetKey) indexPiece, err := (*lis.client).Load(binarySetKey)
if err != nil { if err != nil {
log.Println(err) return nil, err
return nil
} }
index = append(index, []byte(indexPiece)) index = append(index, []byte(indexPiece))
} }
return index return index, nil
} }
func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMsg *msgstream.LoadIndexMsg) error { func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMsg *msgstream.LoadIndexMsg) error {
@ -239,21 +249,22 @@ func (lis *loadIndexService) updateSegmentIndex(bytesIndex [][]byte, loadIndexMs
return err return err
} }
loadIndexInfo, err := NewLoadIndexInfo() loadIndexInfo, err := newLoadIndexInfo()
defer deleteLoadIndexInfo(loadIndexInfo)
if err != nil { if err != nil {
return err return err
} }
err = loadIndexInfo.AppendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID) err = loadIndexInfo.appendFieldInfo(loadIndexMsg.FieldName, loadIndexMsg.FieldID)
if err != nil { if err != nil {
return err return err
} }
for _, indexParam := range loadIndexMsg.IndexParams { for _, indexParam := range loadIndexMsg.IndexParams {
err = loadIndexInfo.AppendIndexParam(indexParam.Key, indexParam.Value) err = loadIndexInfo.appendIndexParam(indexParam.Key, indexParam.Value)
if err != nil { if err != nil {
return err return err
} }
} }
err = loadIndexInfo.AppendIndex(bytesIndex, loadIndexMsg.IndexPaths) err = loadIndexInfo.appendIndex(bytesIndex, loadIndexMsg.IndexPaths)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,6 +1,7 @@
package querynode package querynode
import ( import (
"context"
"math" "math"
"math/rand" "math/rand"
"sort" "sort"
@ -11,8 +12,26 @@ import (
"github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" internalPb "github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/querynode/client"
) )
func TestLoadIndexClient_LoadIndex(t *testing.T) {
pulsarURL := Params.PulsarAddress
loadIndexChannels := Params.LoadIndexChannelNames
loadIndexClient := client.NewLoadIndexClient(context.Background(), pulsarURL, loadIndexChannels)
loadIndexPath := "collection0-segment0-field0"
loadIndexPaths := make([]string, 0)
loadIndexPaths = append(loadIndexPaths, loadIndexPath)
indexParams := make(map[string]string)
indexParams["index_type"] = "IVF_PQ"
indexParams["index_mode"] = "cpu"
loadIndexClient.LoadIndex(loadIndexPaths, 0, 0, "field0", indexParams)
loadIndexClient.Close()
}
func TestLoadIndexService_PulsarAddress(t *testing.T) { func TestLoadIndexService_PulsarAddress(t *testing.T) {
node := newQueryNode() node := newQueryNode()
collectionID := rand.Int63n(1000000) collectionID := rand.Int63n(1000000)
@ -125,10 +144,18 @@ func TestLoadIndexService_PulsarAddress(t *testing.T) {
statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize) statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize)
statsMs.Start() statsMs.Start()
findFiledStats := false
for {
receiveMsg := msgstream.MsgStream(statsMs).Consume() receiveMsg := msgstream.MsgStream(statsMs).Consume()
assert.NotNil(t, receiveMsg) assert.NotNil(t, receiveMsg)
assert.NotEqual(t, len(receiveMsg.Msgs), 0) assert.NotEqual(t, len(receiveMsg.Msgs), 0)
statsMsg, ok := receiveMsg.Msgs[0].(*msgstream.QueryNodeStatsMsg)
for _, msg := range receiveMsg.Msgs {
statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg)
if statsMsg.FieldStats == nil || len(statsMsg.FieldStats) == 0 {
continue
}
findFiledStats = true
assert.Equal(t, ok, true) assert.Equal(t, ok, true)
assert.Equal(t, len(statsMsg.FieldStats), 1) assert.Equal(t, len(statsMsg.FieldStats), 1)
fieldStats0 := statsMsg.FieldStats[0] fieldStats0 := statsMsg.FieldStats[0]
@ -136,13 +163,19 @@ func TestLoadIndexService_PulsarAddress(t *testing.T) {
assert.Equal(t, fieldStats0.CollectionID, collectionID) assert.Equal(t, fieldStats0.CollectionID, collectionID)
assert.Equal(t, len(fieldStats0.IndexStats), 1) assert.Equal(t, len(fieldStats0.IndexStats), 1)
indexStats0 := fieldStats0.IndexStats[0] indexStats0 := fieldStats0.IndexStats[0]
params := indexStats0.IndexParams params := indexStats0.IndexParams
// sort index params by key // sort index params by key
sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key }) sort.Slice(indexParams, func(i, j int) bool { return indexParams[i].Key < indexParams[j].Key })
indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams) indexEqual := node.loadIndexService.indexParamsEqual(params, indexParams)
assert.Equal(t, indexEqual, true) assert.Equal(t, indexEqual, true)
}
if findFiledStats {
break
}
}
defer assert.Equal(t, findFiledStats, true)
<-node.queryNodeLoopCtx.Done() <-node.queryNodeLoopCtx.Done()
node.Close() node.Close()
} }