From f12366342fde72938366bf789483757a4e246eb0 Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Wed, 25 Nov 2020 16:45:42 +0800 Subject: [PATCH] Fix bug and delete unused code Signed-off-by: zhenshan.cao --- go.mod | 1 + go.sum | 2 + internal/allocator/allocator.go | 2 +- internal/allocator/segment.go | 115 +++++++++++++++++----- internal/allocator/timestamp.go | 3 +- internal/core/src/segcore/segment_c.cpp | 49 ++------- internal/core/src/segcore/segment_c.h | 20 +--- internal/core/unittest/test_c_api.cpp | 18 ++-- internal/proxy/paramtable.go | 25 ++++- internal/proxy/proxy.go | 5 +- internal/proxy/proxy_test.go | 3 +- internal/proxy/repack_func.go | 104 +++++++++++++++---- internal/proxy/task.go | 13 ++- internal/proxy/task_scheduler.go | 18 +--- internal/reader/flow_graph_insert_node.go | 4 +- internal/reader/param_table.go | 4 +- internal/reader/segment.go | 37 +++---- internal/reader/segment_test.go | 1 + 18 files changed, 258 insertions(+), 166 deletions(-) diff --git a/go.mod b/go.mod index 6bf3a79b3b..5016aaf715 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/apache/pulsar-client-go v0.1.1 github.com/aws/aws-sdk-go v1.30.8 github.com/coreos/etcd v3.3.25+incompatible // indirect + github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 github.com/frankban/quicktest v1.10.2 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect diff --git a/go.sum b/go.sum index 3f443fa53e..16af9ce827 100644 --- a/go.sum +++ b/go.sum @@ -65,6 +65,7 @@ github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f h1:lBNOc5arjvs8E5mO2tbp github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548 h1:iwZdTE0PVqJCos1vaoKsclOGD3ADKpshg3SRtYBbwso= github.com/cznic/mathutil v0.0.0-20181122101859-297441e03548/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -329,6 +330,7 @@ github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+Gx github.com/prometheus/procfs v0.1.3 h1:F0+tqvhOksq22sc6iCHF5WGlWjdwj92p0udFh1VFBS8= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446 h1:/NRJ5vAYoqz+7sG51ubIDHXeWO8DlTSrToPu6q11ziA= github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M= github.com/rogpeppe/fastuuid v0.0.0-20150106093220-6724a57986af/go.mod h1:XWv6SoW27p1b0cqNHllgS5HIMJraePCO15w5zCzIWYg= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= diff --git a/internal/allocator/allocator.go b/internal/allocator/allocator.go index 4232e20cba..e3105a68b5 100644 --- a/internal/allocator/allocator.go +++ b/internal/allocator/allocator.go @@ -57,7 +57,7 @@ type segRequest struct { count uint32 colName string partition string - segID UniqueID + segInfo map[UniqueID]uint32 channelID int32 } diff --git a/internal/allocator/segment.go b/internal/allocator/segment.go index 11aefe2736..1e2c565432 100644 --- a/internal/allocator/segment.go +++ b/internal/allocator/segment.go @@ -1,11 +1,15 @@ package allocator import ( + "container/list" "context" "fmt" "log" + "sort" "time" + "github.com/cznic/mathutil" + "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" "github.com/zilliztech/milvus-distributed/internal/errors" @@ -18,7 +22,10 @@ const ( ) type assignInfo struct { - internalpb.SegIDAssignment + collName string + partitionTag string + channelID int32 + segInfo map[UniqueID]uint32 // segmentID->count map expireTime time.Time lastInsertTime time.Time } @@ -32,12 +39,16 @@ func (info *assignInfo) IsActive(now time.Time) bool { } func (info *assignInfo) IsEnough(count uint32) bool { - return info.Count >= count + total := uint32(0) + for _, count := range info.segInfo { + total += count + } + return total >= count } type SegIDAssigner struct { Allocator - assignInfos map[string][]*assignInfo // collectionName -> [] *assignInfo + assignInfos map[string]*list.List // collectionName -> *list.List segReqs []*internalpb.SegIDRequest canDoReqs []request } @@ -50,11 +61,8 @@ func NewSegIDAssigner(ctx context.Context, masterAddr string) (*SegIDAssigner, e cancel: cancel, masterAddress: masterAddr, countPerRPC: SegCountPerRPC, - //toDoReqs: []request, }, - assignInfos: make(map[string][]*assignInfo), - //segReqs: make([]*internalpb.SegIDRequest, maxConcurrentRequests), - //canDoReqs: make([]request, maxConcurrentRequests), + assignInfos: make(map[string]*list.List), } sa.tChan = &ticker{ updateInterval: time.Second, @@ -67,16 +75,17 @@ func NewSegIDAssigner(ctx context.Context, masterAddr string) (*SegIDAssigner, e func (sa *SegIDAssigner) collectExpired() { now := time.Now() - for _, colInfos := range sa.assignInfos { - for _, assign := range colInfos { + for _, info := range sa.assignInfos { + for e := info.Front(); e != nil; e = e.Next() { + assign := e.Value.(*assignInfo) if !assign.IsActive(now) || !assign.IsExpired(now) { continue } sa.segReqs = append(sa.segReqs, &internalpb.SegIDRequest{ - ChannelID: assign.ChannelID, + ChannelID: assign.channelID, Count: sa.countPerRPC, - CollName: assign.CollName, - PartitionTag: assign.PartitionTag, + CollName: assign.collName, + PartitionTag: assign.partitionTag, }) } } @@ -88,7 +97,6 @@ func (sa *SegIDAssigner) checkToDoReqs() { } now := time.Now() for _, req := range sa.toDoReqs { - fmt.Println("DDDDD????", req) segRequest := req.(*segRequest) assign := sa.getAssign(segRequest.colName, segRequest.partition, segRequest.channelID) if assign == nil || assign.IsExpired(now) || !assign.IsEnough(segRequest.count) { @@ -102,13 +110,36 @@ func (sa *SegIDAssigner) checkToDoReqs() { } } +func (sa *SegIDAssigner) removeSegInfo(colName, partition string, channelID int32) { + assignInfos, ok := sa.assignInfos[colName] + if !ok { + return + } + + cnt := assignInfos.Len() + if cnt == 0 { + return + } + + for e := assignInfos.Front(); e != nil; e = e.Next() { + assign := e.Value.(*assignInfo) + if assign.partitionTag != partition || assign.channelID != channelID { + continue + } + assignInfos.Remove(e) + } + +} + func (sa *SegIDAssigner) getAssign(colName, partition string, channelID int32) *assignInfo { - colInfos, ok := sa.assignInfos[colName] + assignInfos, ok := sa.assignInfos[colName] if !ok { return nil } - for _, info := range colInfos { - if info.PartitionTag != partition || info.ChannelID != channelID { + + for e := assignInfos.Front(); e != nil; e = e.Next() { + info := e.Value.(*assignInfo) + if info.partitionTag != partition || info.channelID != channelID { continue } return info @@ -151,19 +182,26 @@ func (sa *SegIDAssigner) syncSegments() { now := time.Now() expiredTime := now.Add(time.Millisecond * time.Duration(resp.ExpireDuration)) + for _, info := range resp.PerChannelAssignment { + sa.removeSegInfo(info.CollName, info.PartitionTag, info.ChannelID) + } + for _, info := range resp.PerChannelAssignment { assign := sa.getAssign(info.CollName, info.PartitionTag, info.ChannelID) if assign == nil { colInfos := sa.assignInfos[info.CollName] + segInfo := make(map[UniqueID]uint32) + segInfo[info.SegID] = info.Count newAssign := &assignInfo{ - SegIDAssignment: *info, - expireTime: expiredTime, - lastInsertTime: now, + collName: info.CollName, + partitionTag: info.PartitionTag, + channelID: info.ChannelID, + segInfo: segInfo, } - colInfos = append(colInfos, newAssign) + colInfos.PushBack(newAssign) sa.assignInfos[info.CollName] = colInfos } else { - assign.SegIDAssignment = *info + assign.segInfo[info.SegID] = info.Count assign.expireTime = expiredTime assign.lastInsertTime = now } @@ -181,13 +219,38 @@ func (sa *SegIDAssigner) processFunc(req request) error { if assign == nil { return errors.New("Failed to GetSegmentID") } - segRequest.segID = assign.SegID - assign.Count -= segRequest.count + + keys := make([]UniqueID, len(assign.segInfo)) + i := 0 + for key := range assign.segInfo { + keys[i] = key + i++ + } + reqCount := segRequest.count + + resultSegInfo := make(map[UniqueID]uint32) + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + for _, key := range keys { + if reqCount <= 0 { + break + } + cur := assign.segInfo[key] + minCnt := mathutil.MinUint32(cur, reqCount) + resultSegInfo[key] = minCnt + cur -= minCnt + reqCount -= minCnt + if cur <= 0 { + delete(assign.segInfo, key) + } else { + assign.segInfo[key] = cur + } + } + segRequest.segInfo = resultSegInfo fmt.Println("process segmentID") return nil } -func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32, count uint32) (UniqueID, error) { +func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32, count uint32) (map[UniqueID]uint32, error) { req := &segRequest{ baseRequest: baseRequest{done: make(chan error), valid: false}, colName: colName, @@ -199,7 +262,7 @@ func (sa *SegIDAssigner) GetSegmentID(colName, partition string, channelID int32 req.Wait() if !req.IsValid() { - return 0, errors.New("GetSegmentID Failed") + return nil, errors.New("GetSegmentID Failed") } - return req.segID, nil + return req.segInfo, nil } diff --git a/internal/allocator/timestamp.go b/internal/allocator/timestamp.go index 6421ad71b3..b96e45fbfb 100644 --- a/internal/allocator/timestamp.go +++ b/internal/allocator/timestamp.go @@ -13,7 +13,7 @@ import ( type Timestamp = typeutil.Timestamp const ( - tsCountPerRPC = 2 << 18 * 10 + tsCountPerRPC = 2 << 15 ) type TimestampAllocator struct { @@ -37,6 +37,7 @@ func NewTimestampAllocator(ctx context.Context, masterAddr string) (*TimestampAl } a.Allocator.syncFunc = a.syncTs a.Allocator.processFunc = a.processFunc + a.Allocator.checkFunc = a.checkFunc return a, nil } diff --git a/internal/core/src/segcore/segment_c.cpp b/internal/core/src/segcore/segment_c.cpp index cdaa967228..fc3be7f6c8 100644 --- a/internal/core/src/segcore/segment_c.cpp +++ b/internal/core/src/segcore/segment_c.cpp @@ -18,7 +18,6 @@ #include #include #include -#include CSegmentBase NewSegment(CCollection collection, uint64_t segment_id) { @@ -42,7 +41,7 @@ DeleteSegment(CSegmentBase segment) { ////////////////////////////////////////////////////////////////// -CStatus +int Insert(CSegmentBase c_segment, int64_t reserved_offset, int64_t size, @@ -58,22 +57,11 @@ Insert(CSegmentBase c_segment, dataChunk.sizeof_per_row = sizeof_per_row; dataChunk.count = count; - try { - auto res = segment->Insert(reserved_offset, size, row_ids, timestamps, dataChunk); - - auto status = CStatus(); - status.error_code = Success; - status.error_msg = ""; - return status; - } catch (std::runtime_error& e) { - auto status = CStatus(); - status.error_code = UnexpectedException; - status.error_msg = strdup(e.what()); - return status; - } + auto res = segment->Insert(reserved_offset, size, row_ids, timestamps, dataChunk); // TODO: delete print // std::cout << "do segment insert, sizeof_per_row = " << sizeof_per_row << std::endl; + return res.code(); } int64_t @@ -85,24 +73,13 @@ PreInsert(CSegmentBase c_segment, int64_t size) { return segment->PreInsert(size); } -CStatus +int Delete( CSegmentBase c_segment, int64_t reserved_offset, int64_t size, const int64_t* row_ids, const uint64_t* timestamps) { auto segment = (milvus::segcore::SegmentBase*)c_segment; - try { - auto res = segment->Delete(reserved_offset, size, row_ids, timestamps); - - auto status = CStatus(); - status.error_code = Success; - status.error_msg = ""; - return status; - } catch (std::runtime_error& e) { - auto status = CStatus(); - status.error_code = UnexpectedException; - status.error_msg = strdup(e.what()); - return status; - } + auto res = segment->Delete(reserved_offset, size, row_ids, timestamps); + return res.code(); } int64_t @@ -114,7 +91,7 @@ PreDelete(CSegmentBase c_segment, int64_t size) { return segment->PreDelete(size); } -CStatus +int Search(CSegmentBase c_segment, CPlan c_plan, CPlaceholderGroup* c_placeholder_groups, @@ -130,22 +107,14 @@ Search(CSegmentBase c_segment, } milvus::segcore::QueryResult query_result; - auto status = CStatus(); - try { - auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result); - status.error_code = Success; - status.error_msg = ""; - } catch (std::runtime_error& e) { - status.error_code = UnexpectedException; - status.error_msg = strdup(e.what()); - } + auto res = segment->Search(plan, placeholder_groups.data(), timestamps, num_groups, query_result); // result_ids and result_distances have been allocated memory in goLang, // so we don't need to malloc here. memcpy(result_ids, query_result.result_ids_.data(), query_result.get_row_count() * sizeof(int64_t)); memcpy(result_distances, query_result.result_distances_.data(), query_result.get_row_count() * sizeof(float)); - return status; + return res.code(); } ////////////////////////////////////////////////////////////////// diff --git a/internal/core/src/segcore/segment_c.h b/internal/core/src/segcore/segment_c.h index c2af7e8305..5e681bc689 100644 --- a/internal/core/src/segcore/segment_c.h +++ b/internal/core/src/segcore/segment_c.h @@ -14,24 +14,12 @@ extern "C" { #endif #include -#include -#include - #include "segcore/collection_c.h" #include "segcore/plan_c.h" +#include typedef void* CSegmentBase; -enum ErrorCode { - Success = 0, - UnexpectedException = 1, -}; - -typedef struct CStatus { - int error_code; - const char* error_msg; -} CStatus; - CSegmentBase NewSegment(CCollection collection, uint64_t segment_id); @@ -40,7 +28,7 @@ DeleteSegment(CSegmentBase segment); ////////////////////////////////////////////////////////////////// -CStatus +int Insert(CSegmentBase c_segment, int64_t reserved_offset, int64_t size, @@ -53,14 +41,14 @@ Insert(CSegmentBase c_segment, int64_t PreInsert(CSegmentBase c_segment, int64_t size); -CStatus +int Delete( CSegmentBase c_segment, int64_t reserved_offset, int64_t size, const int64_t* row_ids, const uint64_t* timestamps); int64_t PreDelete(CSegmentBase c_segment, int64_t size); -CStatus +int Search(CSegmentBase c_segment, CPlan plan, CPlaceholderGroup* placeholder_groups, diff --git a/internal/core/unittest/test_c_api.cpp b/internal/core/unittest/test_c_api.cpp index b16d545741..d25c4a8a3b 100644 --- a/internal/core/unittest/test_c_api.cpp +++ b/internal/core/unittest/test_c_api.cpp @@ -65,7 +65,7 @@ TEST(CApiTest, InsertTest) { auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(res.error_code == Success); + assert(res == 0); DeleteCollection(collection); DeleteSegment(segment); @@ -82,7 +82,7 @@ TEST(CApiTest, DeleteTest) { auto offset = PreDelete(segment, 3); auto del_res = Delete(segment, offset, 3, delete_row_ids, delete_timestamps); - assert(del_res.error_code == Success); + assert(del_res == 0); DeleteCollection(collection); DeleteSegment(segment); @@ -116,7 +116,7 @@ TEST(CApiTest, SearchTest) { auto offset = PreInsert(segment, N); auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(ins_res.error_code == Success); + assert(ins_res == 0); const char* dsl_string = R"( { @@ -163,7 +163,7 @@ TEST(CApiTest, SearchTest) { float result_distances[100]; auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances); - assert(sea_res.error_code == Success); + assert(sea_res == 0); DeletePlan(plan); DeletePlaceholderGroup(placeholderGroup); @@ -199,7 +199,7 @@ TEST(CApiTest, BuildIndexTest) { auto offset = PreInsert(segment, N); auto ins_res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(ins_res.error_code == Success); + assert(ins_res == 0); // TODO: add index ptr Close(segment); @@ -250,7 +250,7 @@ TEST(CApiTest, BuildIndexTest) { float result_distances[100]; auto sea_res = Search(segment, plan, placeholderGroups.data(), timestamps.data(), 1, result_ids, result_distances); - assert(sea_res.error_code == Success); + assert(sea_res == 0); DeletePlan(plan); DeletePlaceholderGroup(placeholderGroup); @@ -315,7 +315,7 @@ TEST(CApiTest, GetMemoryUsageInBytesTest) { auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(res.error_code == Success); + assert(res == 0); auto memory_usage_size = GetMemoryUsageInBytes(segment); @@ -482,7 +482,7 @@ TEST(CApiTest, GetDeletedCountTest) { auto offset = PreDelete(segment, 3); auto del_res = Delete(segment, offset, 3, delete_row_ids, delete_timestamps); - assert(del_res.error_code == Success); + assert(del_res == 0); // TODO: assert(deleted_count == len(delete_row_ids)) auto deleted_count = GetDeletedCount(segment); @@ -502,7 +502,7 @@ TEST(CApiTest, GetRowCountTest) { auto line_sizeof = (sizeof(int) + sizeof(float) * 16); auto offset = PreInsert(segment, N); auto res = Insert(segment, offset, N, uids.data(), timestamps.data(), raw_data.data(), (int)line_sizeof, N); - assert(res.error_code == Success); + assert(res == 0); auto row_count = GetRowCount(segment); assert(row_count == N); diff --git a/internal/proxy/paramtable.go b/internal/proxy/paramtable.go index ba60abe634..c53c845096 100644 --- a/internal/proxy/paramtable.go +++ b/internal/proxy/paramtable.go @@ -96,6 +96,27 @@ func (pt *ParamTable) ProxyIDList() []UniqueID { return ret } +func (pt *ParamTable) queryNodeNum() int { + return len(pt.queryNodeIDList()) +} + +func (pt *ParamTable) queryNodeIDList() []UniqueID { + queryNodeIDStr, err := pt.Load("nodeID.queryNodeIDList") + if err != nil { + panic(err) + } + var ret []UniqueID + queryNodeIDs := strings.Split(queryNodeIDStr, ",") + for _, i := range queryNodeIDs { + v, err := strconv.Atoi(i) + if err != nil { + log.Panicf("load proxy id list error, %s", err.Error()) + } + ret = append(ret, UniqueID(v)) + } + return ret +} + func (pt *ParamTable) ProxyID() UniqueID { proxyID, err := pt.Load("_proxyID") if err != nil { @@ -396,11 +417,11 @@ func (pt *ParamTable) searchChannelNames() []string { } func (pt *ParamTable) searchResultChannelNames() []string { - ch, err := pt.Load("msgChannel.chanNamePrefix.search") + ch, err := pt.Load("msgChannel.chanNamePrefix.searchResult") if err != nil { log.Fatal(err) } - channelRange, err := pt.Load("msgChannel.channelRange.search") + channelRange, err := pt.Load("msgChannel.channelRange.searchResult") if err != nil { panic(err) } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 02e724f821..4f89642eaf 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -55,12 +55,11 @@ func CreateProxy(ctx context.Context) (*Proxy, error) { proxyLoopCancel: cancel, } - // TODO: use config instead pulsarAddress := Params.PulsarAddress() p.queryMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamSearchBufSize()) p.queryMsgStream.SetPulsarClient(pulsarAddress) - p.queryMsgStream.CreatePulsarProducers(Params.SearchChannelNames()) + p.queryMsgStream.CreatePulsarProducers(Params.searchChannelNames()) masterAddr := Params.MasterAddress() idAllocator, err := allocator.NewIDAllocator(p.proxyLoopCtx, masterAddr) @@ -84,7 +83,7 @@ func CreateProxy(ctx context.Context) (*Proxy, error) { p.manipulationMsgStream = msgstream.NewPulsarMsgStream(p.proxyLoopCtx, Params.MsgStreamInsertBufSize()) p.manipulationMsgStream.SetPulsarClient(pulsarAddress) - p.manipulationMsgStream.CreatePulsarProducers(Params.InsertChannelNames()) + p.manipulationMsgStream.CreatePulsarProducers(Params.insertChannelNames()) repackFuncImpl := func(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) { return insertRepackFunc(tsMsgs, hashKeys, p.segAssigner, false) } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 09c299bc5f..4ffcaf18e0 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -229,7 +229,7 @@ func TestProxy_Insert(t *testing.T) { collectionName := "CreateCollection" + strconv.FormatInt(int64(i), 10) req := &servicepb.RowBatch{ CollectionName: collectionName, - PartitionTag: "", + PartitionTag: "haha", RowData: make([]*commonpb.Blob, 0), HashKeys: make([]int32, 0), } @@ -237,6 +237,7 @@ func TestProxy_Insert(t *testing.T) { wg.Add(1) go func(group *sync.WaitGroup) { defer group.Done() + createCollection(t, collectionName) has := hasCollection(t, collectionName) if has { resp, err := proxyClient.Insert(ctx, req) diff --git a/internal/proxy/repack_func.go b/internal/proxy/repack_func.go index ec12f475e2..781c188936 100644 --- a/internal/proxy/repack_func.go +++ b/internal/proxy/repack_func.go @@ -1,6 +1,9 @@ package proxy import ( + "log" + "sort" + "github.com/zilliztech/milvus-distributed/internal/allocator" "github.com/zilliztech/milvus-distributed/internal/errors" "github.com/zilliztech/milvus-distributed/internal/msgstream" @@ -15,6 +18,9 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, result := make(map[int32]*msgstream.MsgPack) + channelCountMap := make(map[UniqueID]map[int32]uint32) // reqID --> channelID to count + reqSchemaMap := make(map[UniqueID][]string) + for i, request := range tsMsgs { if request.Type() != internalpb.MsgType_kInsert { return nil, errors.New(string("msg's must be Insert")) @@ -23,8 +29,8 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, if !ok { return nil, errors.New(string("msg's must be Insert")) } - keys := hashKeys[i] + keys := hashKeys[i] timestampLen := len(insertRequest.Timestamps) rowIDLen := len(insertRequest.RowIDs) rowDataLen := len(insertRequest.RowData) @@ -34,10 +40,84 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, return nil, errors.New(string("the length of hashValue, timestamps, rowIDs, RowData are not equal")) } + reqID := insertRequest.ReqID + if _, ok := channelCountMap[reqID]; !ok { + channelCountMap[reqID] = make(map[int32]uint32) + } + + if _, ok := reqSchemaMap[reqID]; !ok { + reqSchemaMap[reqID] = []string{insertRequest.CollectionName, insertRequest.PartitionTag} + } + + for _, channelID := range keys { + channelCountMap[reqID][channelID]++ + } + + } + + reqSegCountMap := make(map[UniqueID]map[int32]map[UniqueID]uint32) + + for reqID, countInfo := range channelCountMap { + schema := reqSchemaMap[reqID] + collName, partitionTag := schema[0], schema[1] + for channelID, count := range countInfo { + mapInfo, err := segIDAssigner.GetSegmentID(collName, partitionTag, channelID, count) + if err != nil { + return nil, err + } + reqSegCountMap[reqID][channelID] = mapInfo + } + } + + reqSegAccumulateCountMap := make(map[UniqueID]map[int32][]uint32) + reqSegIDMap := make(map[UniqueID]map[int32][]UniqueID) + reqSegAllocateCounter := make(map[UniqueID]map[int32]uint32) + + for reqID, channelInfo := range reqSegCountMap { + for channelID, segInfo := range channelInfo { + reqSegAllocateCounter[reqID][channelID] = 0 + keys := make([]UniqueID, len(segInfo)) + i := 0 + for key := range segInfo { + keys[i] = key + i++ + } + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + accumulate := uint32(0) + for _, key := range keys { + accumulate += segInfo[key] + reqSegAccumulateCountMap[reqID][channelID] = append( + reqSegAccumulateCountMap[reqID][channelID], + accumulate, + ) + reqSegIDMap[reqID][channelID] = append( + reqSegIDMap[reqID][channelID], + key, + ) + } + } + } + + var getSegmentID = func(reqID UniqueID, channelID int32) UniqueID { + reqSegAllocateCounter[reqID][channelID]++ + cur := reqSegAllocateCounter[reqID][channelID] + accumulateSlice := reqSegAccumulateCountMap[reqID][channelID] + segIDSlice := reqSegIDMap[reqID][channelID] + for index, count := range accumulateSlice { + if cur <= count { + return segIDSlice[index] + } + } + log.Panic("Can't Found SegmentID") + return 0 + } + + for i, request := range tsMsgs { + insertRequest := request.(*msgstream.InsertMsg) + keys := hashKeys[i] reqID := insertRequest.ReqID collectionName := insertRequest.CollectionName partitionTag := insertRequest.PartitionTag - channelID := insertRequest.ChannelID proxyID := insertRequest.ProxyID for index, key := range keys { ts := insertRequest.Timestamps[index] @@ -48,13 +128,14 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, msgPack := msgstream.MsgPack{} result[key] = &msgPack } + segmentID := getSegmentID(reqID, key) sliceRequest := internalpb.InsertRequest{ MsgType: internalpb.MsgType_kInsert, ReqID: reqID, CollectionName: collectionName, PartitionTag: partitionTag, - SegmentID: 0, // will be assigned later if together - ChannelID: channelID, + SegmentID: segmentID, + ChannelID: int64(key), ProxyID: proxyID, Timestamps: []uint64{ts}, RowIDs: []int64{rowID}, @@ -73,25 +154,10 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg, accMsgs.RowData = append(accMsgs.RowData, row) } } else { // every row is a message - segID, _ := segIDAssigner.GetSegmentID(collectionName, partitionTag, int32(channelID), 1) - insertMsg.SegmentID = segID result[key].Msgs = append(result[key].Msgs, insertMsg) } } } - if together { - for key := range result { - insertMsg, _ := result[key].Msgs[0].(*msgstream.InsertMsg) - rowNums := len(insertMsg.RowIDs) - collectionName := insertMsg.CollectionName - partitionTag := insertMsg.PartitionTag - channelID := insertMsg.ChannelID - segID, _ := segIDAssigner.GetSegmentID(collectionName, partitionTag, int32(channelID), uint32(rowNums)) - insertMsg.SegmentID = segID - result[key].Msgs[0] = insertMsg - } - } - return result, nil } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index e2f83e50b9..65e4e214ee 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -34,7 +34,6 @@ type BaseInsertTask = msgstream.InsertMsg type InsertTask struct { BaseInsertTask Condition - ts Timestamp result *servicepb.IntegerRangeResponse manipulationMsgStream *msgstream.PulsarMsgStream ctx context.Context @@ -46,15 +45,21 @@ func (it *InsertTask) SetID(uid UniqueID) { } func (it *InsertTask) SetTs(ts Timestamp) { - it.ts = ts + rowNum := len(it.RowData) + it.Timestamps = make([]uint64, rowNum) + for index := range it.Timestamps { + it.Timestamps[index] = ts + } + it.BeginTimestamp = ts + it.EndTimestamp = ts } func (it *InsertTask) BeginTs() Timestamp { - return it.ts + return it.BeginTimestamp } func (it *InsertTask) EndTs() Timestamp { - return it.ts + return it.EndTimestamp } func (it *InsertTask) ID() UniqueID { diff --git a/internal/proxy/task_scheduler.go b/internal/proxy/task_scheduler.go index 4eccf249e1..f08d4672ff 100644 --- a/internal/proxy/task_scheduler.go +++ b/internal/proxy/task_scheduler.go @@ -186,16 +186,7 @@ type DqTaskQueue struct { func (queue *DdTaskQueue) Enqueue(t task) error { queue.lock.Lock() defer queue.lock.Unlock() - - ts, _ := queue.sched.tsoAllocator.AllocOne() - log.Printf("[Proxy] allocate timestamp: %v", ts) - t.SetTs(ts) - - reqID, _ := queue.sched.idAllocator.AllocOne() - log.Printf("[Proxy] allocate reqID: %v", reqID) - t.SetID(reqID) - - return queue.addUnissuedTask(t) + return queue.BaseTaskQueue.Enqueue(t) } func NewDdTaskQueue(sched *TaskScheduler) *DdTaskQueue { @@ -369,14 +360,14 @@ func (sched *TaskScheduler) queryLoop() { func (sched *TaskScheduler) queryResultLoop() { defer sched.wg.Done() - // TODO: use config instead unmarshal := msgstream.NewUnmarshalDispatcher() queryResultMsgStream := msgstream.NewPulsarMsgStream(sched.ctx, Params.MsgStreamSearchResultBufSize()) queryResultMsgStream.SetPulsarClient(Params.PulsarAddress()) - queryResultMsgStream.CreatePulsarConsumers(Params.SearchResultChannelNames(), + queryResultMsgStream.CreatePulsarConsumers(Params.searchResultChannelNames(), Params.ProxySubName(), unmarshal, Params.MsgStreamSearchResultPulsarBufSize()) + queryNodeNum := Params.queryNodeNum() queryResultMsgStream.Start() defer queryResultMsgStream.Close() @@ -401,8 +392,7 @@ func (sched *TaskScheduler) queryResultLoop() { queryResultBuf[reqID] = make([]*internalpb.SearchResult, 0) } queryResultBuf[reqID] = append(queryResultBuf[reqID], &searchResultMsg.SearchResult) - if len(queryResultBuf[reqID]) == 4 { - // TODO: use the number of query node instead + if len(queryResultBuf[reqID]) == queryNodeNum { t := sched.getTaskByReqID(reqID) if t != nil { qt, ok := t.(*QueryTask) diff --git a/internal/reader/flow_graph_insert_node.go b/internal/reader/flow_graph_insert_node.go index 3f0f2be310..2d56e0bd57 100644 --- a/internal/reader/flow_graph_insert_node.go +++ b/internal/reader/flow_graph_insert_node.go @@ -106,7 +106,6 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn if err != nil { log.Println("cannot find segment:", segmentID) // TODO: add error handling - wg.Done() return } @@ -117,9 +116,8 @@ func (iNode *insertNode) insert(insertData *InsertData, segmentID int64, wg *syn err = targetSegment.segmentInsert(offsets, &ids, ×tamps, &records) if err != nil { - log.Println(err) + log.Println("insert failed") // TODO: add error handling - wg.Done() return } diff --git a/internal/reader/param_table.go b/internal/reader/param_table.go index c104130afb..abae38270b 100644 --- a/internal/reader/param_table.go +++ b/internal/reader/param_table.go @@ -273,11 +273,11 @@ func (p *ParamTable) searchChannelNames() []string { } func (p *ParamTable) searchResultChannelNames() []string { - ch, err := p.Load("msgChannel.chanNamePrefix.search") + ch, err := p.Load("msgChannel.chanNamePrefix.searchResult") if err != nil { log.Fatal(err) } - channelRange, err := p.Load("msgChannel.channelRange.search") + channelRange, err := p.Load("msgChannel.channelRange.searchResult") if err != nil { panic(err) } diff --git a/internal/reader/segment.go b/internal/reader/segment.go index a1921788ad..1c7ab40ee2 100644 --- a/internal/reader/segment.go +++ b/internal/reader/segment.go @@ -109,7 +109,7 @@ func (s *Segment) segmentPreDelete(numOfRecords int) int64 { //-------------------------------------------------------------------------------------- dm & search functions func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps *[]Timestamp, records *[]*commonpb.Blob) error { /* - CStatus + int Insert(CSegmentBase c_segment, long int reserved_offset, signed long int size, @@ -148,12 +148,8 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps cSizeofPerRow, cNumOfRows) - errorCode := status.error_code - - if errorCode != 0 { - errorMsg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New("Insert failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + if status != 0 { + return errors.New("Insert failed, error code = " + strconv.Itoa(int(status))) } s.recentlyModified = true @@ -162,7 +158,7 @@ func (s *Segment) segmentInsert(offset int64, entityIDs *[]UniqueID, timestamps func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps *[]Timestamp) error { /* - CStatus + int Delete(CSegmentBase c_segment, long int reserved_offset, long size, @@ -176,12 +172,8 @@ func (s *Segment) segmentDelete(offset int64, entityIDs *[]UniqueID, timestamps var status = C.Delete(s.segmentPtr, cOffset, cSize, cEntityIdsPtr, cTimestampsPtr) - errorCode := status.error_code - - if errorCode != 0 { - errorMsg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New("Delete failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + if status != 0 { + return errors.New("Delete failed, error code = " + strconv.Itoa(int(status))) } return nil @@ -195,8 +187,7 @@ func (s *Segment) segmentSearch(plan *Plan, numQueries int64, topK int64) error { /* - CStatus - Search(void* plan, + void* Search(void* plan, void* placeholder_groups, uint64_t* timestamps, int num_groups, @@ -220,20 +211,16 @@ func (s *Segment) segmentSearch(plan *Plan, var cNumGroups = C.int(len(placeHolderGroups)) var status = C.Search(s.segmentPtr, plan.cPlan, cPlaceHolder, cTimestamp, cNumGroups, cNewResultIds, cNewResultDistances) - errorCode := status.error_code - - if errorCode != 0 { - errorMsg := C.GoString(status.error_msg) - defer C.free(unsafe.Pointer(status.error_msg)) - return errors.New("Search failed, C runtime error detected, error code = " + strconv.Itoa(int(errorCode)) + ", error msg = " + errorMsg) + if status != 0 { + return errors.New("search failed, error code = " + strconv.Itoa(int(status))) } cNumQueries := C.long(numQueries) cTopK := C.long(topK) // reduce search result - mergeStatus := C.MergeInto(cNumQueries, cTopK, cResultDistances, cResultIds, cNewResultDistances, cNewResultIds) - if mergeStatus != 0 { - return errors.New("merge search result failed, error code = " + strconv.Itoa(int(mergeStatus))) + status = C.MergeInto(cNumQueries, cTopK, cResultDistances, cResultIds, cNewResultDistances, cNewResultIds) + if status != 0 { + return errors.New("merge search result failed, error code = " + strconv.Itoa(int(status))) } return nil } diff --git a/internal/reader/segment_test.go b/internal/reader/segment_test.go index b970c18e30..c14302ff1f 100644 --- a/internal/reader/segment_test.go +++ b/internal/reader/segment_test.go @@ -463,6 +463,7 @@ func TestSegment_segmentInsert(t *testing.T) { err := segment.segmentInsert(offset, &ids, ×tamps, &records) assert.NoError(t, err) + deleteSegment(segment) deleteCollection(collection) }