From c71cd40f68913a4dd814f18e3f4146daf1aff7db Mon Sep 17 00:00:00 2001 From: sunby Date: Sat, 19 Dec 2020 12:55:24 +0800 Subject: [PATCH] Refactor segment manager Signed-off-by: sunby --- Makefile | 2 +- internal/core/src/query/Plan.cpp | 259 +++------- .../src/query/visitors/ExecExprVisitor.cpp | 1 - internal/core/unittest/test_expr.cpp | 101 ---- internal/master/grpc_service.go | 2 +- internal/master/master.go | 36 +- internal/master/persistent_scheduler.go | 30 ++ internal/master/segment_assigner.go | 187 ++++++++ internal/master/segment_manager.go | 452 +++++++++--------- internal/master/segment_manager_test.go | 338 ++++++------- internal/master/stats_processor.go | 58 +++ internal/master/time_sync_producer.go | 21 +- internal/util/typeutil/schema.go | 36 ++ 13 files changed, 794 insertions(+), 729 deletions(-) create mode 100644 internal/master/persistent_scheduler.go create mode 100644 internal/master/segment_assigner.go create mode 100644 internal/master/stats_processor.go create mode 100644 internal/util/typeutil/schema.go diff --git a/Makefile b/Makefile index 1e19455f47..c2081907a1 100644 --- a/Makefile +++ b/Makefile @@ -61,7 +61,7 @@ ruleguard: verifiers: getdeps cppcheck fmt lint ruleguard # Builds various components locally. -build-go: build-cpp +build-go: @echo "Building each component's binary to './bin'" @echo "Building master ..." 
@mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="0" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/master $(PWD)/cmd/master/main.go 1>/dev/null diff --git a/internal/core/src/query/Plan.cpp b/internal/core/src/query/Plan.cpp index 5e0b07a792..a8d21f7359 100644 --- a/internal/core/src/query/Plan.cpp +++ b/internal/core/src/query/Plan.cpp @@ -20,7 +20,6 @@ #include #include #include -#include namespace milvus::query { @@ -40,8 +39,10 @@ const std::map RangeExpr::mapping_ = { class Parser { public: - friend std::unique_ptr - CreatePlan(const Schema& schema, const std::string& dsl_str); + static std::unique_ptr + CreatePlan(const Schema& schema, const std::string& dsl_str) { + return Parser(schema).CreatePlanImpl(dsl_str); + } private: std::unique_ptr @@ -50,55 +51,29 @@ class Parser { explicit Parser(const Schema& schema) : schema(schema) { } - // vector node parser, should be called exactly once per pass. std::unique_ptr ParseVecNode(const Json& out_body); - // Dispatcher of all parse function - // NOTE: when nullptr, it is a pure vector node - ExprPtr - ParseAnyNode(const Json& body); - - ExprPtr - ParseMustNode(const Json& body); - - ExprPtr - ParseShouldNode(const Json& body); - - ExprPtr - ParseShouldNotNode(const Json& body); - - // parse the value of "should"/"must"/"should_not" entry - std::vector - ParseItemList(const Json& body); - - // parse the value of "range" entry - ExprPtr - ParseRangeNode(const Json& out_body); - - // parse the value of "term" entry - ExprPtr - ParseTermNode(const Json& out_body); - - private: - // template implementation of leaf parser - // used by corresponding parser - template - ExprPtr + std::unique_ptr ParseRangeNodeImpl(const std::string& field_name, const Json& body); template - ExprPtr + std::unique_ptr ParseTermNodeImpl(const std::string& field_name, const Json& body); + std::unique_ptr + ParseRangeNode(const Json& out_body); + + std::unique_ptr + ParseTermNode(const Json& out_body); + private: const Schema& schema; 
std::map tag2field_; // PlaceholderName -> FieldId - std::optional> vector_node_opt_; }; -ExprPtr +std::unique_ptr Parser::ParseRangeNode(const Json& out_body) { Assert(out_body.is_object()); Assert(out_body.size() == 1); @@ -109,8 +84,9 @@ Parser::ParseRangeNode(const Json& out_body) { Assert(!field_is_vector(data_type)); switch (data_type) { - case DataType::BOOL: + case DataType::BOOL: { return ParseRangeNodeImpl(field_name, body); + } case DataType::INT8: return ParseRangeNodeImpl(field_name, body); case DataType::INT16: @@ -130,22 +106,51 @@ Parser::ParseRangeNode(const Json& out_body) { std::unique_ptr Parser::CreatePlanImpl(const std::string& dsl_str) { - auto dsl = Json::parse(dsl_str); - auto bool_dsl = dsl.at("bool"); - auto predicate = ParseAnyNode(bool_dsl); - Assert(vector_node_opt_.has_value()); - auto vec_node = std::move(vector_node_opt_).value(); - if (predicate != nullptr) { - vec_node->predicate_ = std::move(predicate); - } - auto plan = std::make_unique(schema); + auto dsl = nlohmann::json::parse(dsl_str); + nlohmann::json vec_pack; + std::optional> predicate; + // top level + auto& bool_dsl = dsl.at("bool"); + if (bool_dsl.contains("must")) { + auto& packs = bool_dsl.at("must"); + Assert(packs.is_array()); + for (auto& pack : packs) { + if (pack.contains("vector")) { + auto& out_body = pack.at("vector"); + plan->plan_node_ = ParseVecNode(out_body); + } else if (pack.contains("term")) { + AssertInfo(!predicate, "unsupported complex DSL"); + auto& out_body = pack.at("term"); + predicate = ParseTermNode(out_body); + } else if (pack.contains("range")) { + AssertInfo(!predicate, "unsupported complex DSL"); + auto& out_body = pack.at("range"); + predicate = ParseRangeNode(out_body); + } else { + PanicInfo("unsupported node"); + } + } + AssertInfo(plan->plan_node_, "vector node not found"); + } else if (bool_dsl.contains("vector")) { + auto& out_body = bool_dsl.at("vector"); + plan->plan_node_ = ParseVecNode(out_body); + Assert(plan->plan_node_); + } 
else { + PanicInfo("Unsupported DSL"); + } + plan->plan_node_->predicate_ = std::move(predicate); plan->tag2field_ = std::move(tag2field_); - plan->plan_node_ = std::move(vec_node); + // TODO: target_entry parser + // if schema autoid is true, + // prepend target_entries_ with row_id + // else + // with primary_key + // return plan; } -ExprPtr +std::unique_ptr Parser::ParseTermNode(const Json& out_body) { Assert(out_body.size() == 1); auto out_iter = out_body.begin(); @@ -216,7 +221,7 @@ Parser::ParseVecNode(const Json& out_body) { } template -ExprPtr +std::unique_ptr Parser::ParseTermNodeImpl(const std::string& field_name, const Json& body) { auto expr = std::make_unique>(); auto data_type = schema[field_name].get_data_type(); @@ -244,7 +249,7 @@ Parser::ParseTermNodeImpl(const std::string& field_name, const Json& body) { } template -ExprPtr +std::unique_ptr Parser::ParseRangeNodeImpl(const std::string& field_name, const Json& body) { auto expr = std::make_unique>(); auto data_type = schema[field_name].get_data_type(); @@ -273,6 +278,12 @@ Parser::ParseRangeNodeImpl(const std::string& field_name, const Json& body) { return expr; } +std::unique_ptr +CreatePlan(const Schema& schema, const std::string& dsl_str) { + auto plan = Parser::CreatePlan(schema, dsl_str); + return plan; +} + std::unique_ptr ParsePlaceholderGroup(const Plan* plan, const std::string& blob) { namespace ser = milvus::proto::service; @@ -302,150 +313,6 @@ ParsePlaceholderGroup(const Plan* plan, const std::string& blob) { return result; } -std::unique_ptr -CreatePlan(const Schema& schema, const std::string& dsl_str) { - auto plan = Parser(schema).CreatePlanImpl(dsl_str); - return plan; -} - -std::vector -Parser::ParseItemList(const Json& body) { - std::vector results; - if (body.is_object()) { - // only one item; - auto new_entry = ParseAnyNode(body); - results.emplace_back(std::move(new_entry)); - } else { - // item array - Assert(body.is_array()); - for (auto& item : body) { - auto new_entry = 
ParseAnyNode(item); - results.emplace_back(std::move(new_entry)); - } - } - auto old_size = results.size(); - - auto new_end = std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) { return x == nullptr; }); - - results.resize(new_end - results.begin()); - - return results; -} - -ExprPtr -Parser::ParseAnyNode(const Json& out_body) { - Assert(out_body.is_object()); - Assert(out_body.size() == 1); - - auto out_iter = out_body.begin(); - - auto key = out_iter.key(); - auto body = out_iter.value(); - - if (key == "must") { - return ParseMustNode(body); - } else if (key == "should") { - return ParseShouldNode(body); - } else if (key == "should_not") { - return ParseShouldNotNode(body); - } else if (key == "range") { - return ParseRangeNode(body); - } else if (key == "term") { - return ParseTermNode(body); - } else if (key == "vector") { - auto vec_node = ParseVecNode(body); - Assert(!vector_node_opt_.has_value()); - vector_node_opt_ = std::move(vec_node); - return nullptr; - } else { - PanicInfo("unsupported key: " + key); - } -} - -template -static ExprPtr -ConstructTree(Merger merger, std::vector item_list) { - if (item_list.size() == 0) { - return nullptr; - } - - if (item_list.size() == 1) { - return std::move(item_list[0]); - } - - // Note: use deque to construct a binary tree - // Op - // / \ - // Op Op - // | \ | \ - // A B C D - std::deque binary_queue; - for (auto& item : item_list) { - Assert(item != nullptr); - binary_queue.push_back(std::move(item)); - } - while (binary_queue.size() > 1) { - auto left = std::move(binary_queue.front()); - binary_queue.pop_front(); - auto right = std::move(binary_queue.front()); - binary_queue.pop_front(); - binary_queue.push_back(merger(std::move(left), std::move(right))); - } - Assert(binary_queue.size() == 1); - return std::move(binary_queue.front()); -} - -ExprPtr -Parser::ParseMustNode(const Json& body) { - auto item_list = ParseItemList(body); - auto merger = [](ExprPtr left, ExprPtr right) { - using OpType 
= BoolBinaryExpr::OpType; - auto res = std::make_unique(); - res->op_type_ = OpType::LogicalAnd; - res->left_ = std::move(left); - res->right_ = std::move(right); - return res; - }; - return ConstructTree(merger, std::move(item_list)); -} - -ExprPtr -Parser::ParseShouldNode(const Json& body) { - auto item_list = ParseItemList(body); - Assert(item_list.size() >= 1); - auto merger = [](ExprPtr left, ExprPtr right) { - using OpType = BoolBinaryExpr::OpType; - auto res = std::make_unique(); - res->op_type_ = OpType::LogicalOr; - res->left_ = std::move(left); - res->right_ = std::move(right); - return res; - }; - return ConstructTree(merger, std::move(item_list)); -} - -ExprPtr -Parser::ParseShouldNotNode(const Json& body) { - auto item_list = ParseItemList(body); - Assert(item_list.size() >= 1); - auto merger = [](ExprPtr left, ExprPtr right) { - using OpType = BoolBinaryExpr::OpType; - auto res = std::make_unique(); - res->op_type_ = OpType::LogicalAnd; - res->left_ = std::move(left); - res->right_ = std::move(right); - return res; - }; - auto subtree = ConstructTree(merger, std::move(item_list)); - - using OpType = BoolUnaryExpr::OpType; - auto res = std::make_unique(); - res->op_type_ = OpType::LogicalNot; - res->child_ = std::move(subtree); - - return res; -} - int64_t GetTopK(const Plan* plan) { return plan->plan_node_->query_info_.topK_; diff --git a/internal/core/src/query/visitors/ExecExprVisitor.cpp b/internal/core/src/query/visitors/ExecExprVisitor.cpp index e45969c38c..35d6bd2849 100644 --- a/internal/core/src/query/visitors/ExecExprVisitor.cpp +++ b/internal/core/src/query/visitors/ExecExprVisitor.cpp @@ -67,7 +67,6 @@ ExecExprVisitor::visit(BoolUnaryExpr& expr) { switch (expr.op_type_) { case OpType::LogicalNot: { chunk.flip(); - break; } default: { PanicInfo("Invalid OpType"); diff --git a/internal/core/unittest/test_expr.cpp b/internal/core/unittest/test_expr.cpp index e7421dfd39..74744d3c78 100644 --- a/internal/core/unittest/test_expr.cpp +++ 
b/internal/core/unittest/test_expr.cpp @@ -410,104 +410,3 @@ TEST(Expr, TestTerm) { } } } - -TEST(Expr, TestSimpleDsl) { - using namespace milvus::query; - using namespace milvus::segcore; - - auto vec_dsl = Json::parse(R"( - { - "vector": { - "fakevec": { - "metric_type": "L2", - "params": { - "nprobe": 10 - }, - "query": "$0", - "topk": 10 - } - } - } -)"); - - int N = 32; - auto get_item = [&](int base, int bit = 1) { - std::vector terms; - // note: random gen range is [0, 2N) - for (int i = 0; i < N * 2; ++i) { - if (((i >> base) & 0x1) == bit) { - terms.push_back(i); - } - } - Json s; - s["term"]["age"]["values"] = terms; - return s; - }; - // std::cout << get_item(0).dump(-2); - // std::cout << vec_dsl.dump(-2); - std::vector>> testcases; - { - Json dsl; - dsl["must"] = Json::array({vec_dsl, get_item(0), get_item(1), get_item(2, 0), get_item(3)}); - testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) == 0b1011; }); - } - - { - Json dsl; - Json sub_dsl; - sub_dsl["must"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)}); - dsl["must"] = Json::array({sub_dsl, vec_dsl}); - testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) == 0b1011; }); - } - - { - Json dsl; - Json sub_dsl; - sub_dsl["should"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)}); - dsl["must"] = Json::array({sub_dsl, vec_dsl}); - testcases.emplace_back(dsl, [](int x) { return !!((x & 0b1111) ^ 0b0100); }); - } - - { - Json dsl; - Json sub_dsl; - sub_dsl["should_not"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)}); - dsl["must"] = Json::array({sub_dsl, vec_dsl}); - testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) != 0b1011; }); - } - - auto schema = std::make_shared(); - schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2); - schema->AddField("age", DataType::INT32); - - auto seg = CreateSegment(schema); - std::vector age_col; - int num_iters = 100; - for (int iter = 0; iter < 
num_iters; ++iter) { - auto raw_data = DataGen(schema, N, iter); - auto new_age_col = raw_data.get_col(1); - age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end()); - seg->PreInsert(N); - seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_); - } - - auto seg_promote = dynamic_cast(seg.get()); - ExecExprVisitor visitor(*seg_promote); - for (auto [clause, ref_func] : testcases) { - Json dsl; - dsl["bool"] = clause; - // std::cout << dsl.dump(2); - auto plan = CreatePlan(*schema, dsl.dump()); - auto final = visitor.call_child(*plan->plan_node_->predicate_.value()); - EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk)); - - for (int i = 0; i < N * num_iters; ++i) { - auto vec_id = i / DefaultElementPerChunk; - auto offset = i % DefaultElementPerChunk; - bool ans = final[vec_id][offset]; - auto val = age_col[i]; - auto ref = ref_func(val); - ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val; - } - } -} diff --git a/internal/master/grpc_service.go b/internal/master/grpc_service.go index e7071424b8..b66949171c 100644 --- a/internal/master/grpc_service.go +++ b/internal/master/grpc_service.go @@ -437,7 +437,7 @@ func (s *Master) AllocID(ctx context.Context, request *internalpb.IDRequest) (*i } func (s *Master) AssignSegmentID(ctx context.Context, request *internalpb.AssignSegIDRequest) (*internalpb.AssignSegIDResponse, error) { - segInfos, err := s.segmentMgr.AssignSegmentID(request.GetPerChannelReq()) + segInfos, err := s.segmentManager.AssignSegment(request.GetPerChannelReq()) if err != nil { return &internalpb.AssignSegIDResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR}, diff --git a/internal/master/master.go b/internal/master/master.go index 555ccc3398..36b2b98c86 100644 --- a/internal/master/master.go +++ b/internal/master/master.go @@ -54,7 +54,9 @@ type Master struct { startCallbacks []func() closeCallbacks []func() - segmentMgr *SegmentManager + 
segmentManager *SegmentManager + segmentAssigner *SegmentAssigner + statProcessor *StatsProcessor segmentStatusMsg ms.MsgStream //id allocator @@ -137,6 +139,11 @@ func CreateServer(ctx context.Context) (*Master, error) { pulsarK2SStream.CreatePulsarProducers(Params.K2SChannelNames) tsMsgProducer.SetK2sSyncStream(pulsarK2SStream) + proxyTtBarrierWatcher := make(chan *ms.TimeTickMsg, 1024) + writeNodeTtBarrierWatcher := make(chan *ms.TimeTickMsg, 1024) + tsMsgProducer.WatchProxyTtBarrier(proxyTtBarrierWatcher) + tsMsgProducer.WatchWriteNodeTtBarrier(writeNodeTtBarrierWatcher) + // stats msg stream statsMs := ms.NewPulsarMsgStream(ctx, 1024) statsMs.SetPulsarClient(pulsarAddr) @@ -169,9 +176,23 @@ func CreateServer(ctx context.Context) (*Master, error) { m.scheduler.SetDDMsgStream(pulsarDDStream) m.scheduler.SetIDAllocator(func() (UniqueID, error) { return m.idAllocator.AllocOne() }) - m.segmentMgr = NewSegmentManager(metakv, + m.segmentAssigner = NewSegmentAssigner(ctx, metakv, + func() (Timestamp, error) { return m.tsoAllocator.AllocOne() }, + proxyTtBarrierWatcher, + ) + m.segmentManager, err = NewSegmentManager(ctx, metakv, func() (UniqueID, error) { return m.idAllocator.AllocOne() }, func() (Timestamp, error) { return m.tsoAllocator.AllocOne() }, + writeNodeTtBarrierWatcher, + &MockFlushScheduler{}, // todo replace mock with real flush scheduler + m.segmentAssigner) + + if err != nil { + return nil, err + } + + m.statProcessor = NewStatsProcessor(metakv, + func() (Timestamp, error) { return m.tsoAllocator.AllocOne() }, ) m.grpcServer = grpc.NewServer() @@ -199,7 +220,8 @@ func (s *Master) Close() { log.Print("closing server") s.stopServerLoop() - + s.segmentAssigner.Close() + s.segmentManager.Close() if s.kvBase != nil { s.kvBase.Close() } @@ -226,6 +248,8 @@ func (s *Master) Run(grpcPort int64) error { if err := s.startServerLoop(s.ctx, grpcPort); err != nil { return err } + s.segmentAssigner.Start() + s.segmentManager.Start() atomic.StoreInt64(&s.isServing, 1) 
// Run callbacks @@ -268,7 +292,7 @@ func (s *Master) startServerLoop(ctx context.Context, grpcPort int64) error { } s.serverLoopWg.Add(1) - go s.segmentStatisticsLoop() + go s.statisticsLoop() s.serverLoopWg.Add(1) go s.tsLoop() @@ -348,7 +372,7 @@ func (s *Master) tsLoop() { } } -func (s *Master) segmentStatisticsLoop() { +func (s *Master) statisticsLoop() { defer s.serverLoopWg.Done() defer s.segmentStatusMsg.Close() ctx, cancel := context.WithCancel(s.serverLoopCtx) @@ -357,7 +381,7 @@ func (s *Master) segmentStatisticsLoop() { for { select { case msg := <-s.segmentStatusMsg.Chan(): - err := s.segmentMgr.HandleQueryNodeMsgPack(msg) + err := s.statProcessor.ProcessQueryNodeStats(msg) if err != nil { log.Println(err) } diff --git a/internal/master/persistent_scheduler.go b/internal/master/persistent_scheduler.go new file mode 100644 index 0000000000..bc4884ced1 --- /dev/null +++ b/internal/master/persistent_scheduler.go @@ -0,0 +1,30 @@ +package master + +type persistenceScheduler interface { + Enqueue(interface{}) error + schedule(interface{}) error + scheduleLoop() + + Start() error + Close() +} +type MockFlushScheduler struct { +} + +func (m *MockFlushScheduler) Enqueue(i interface{}) error { + return nil +} + +func (m *MockFlushScheduler) schedule(i interface{}) error { + return nil +} + +func (m *MockFlushScheduler) scheduleLoop() { +} + +func (m *MockFlushScheduler) Start() error { + return nil +} + +func (m *MockFlushScheduler) Close() { +} diff --git a/internal/master/segment_assigner.go b/internal/master/segment_assigner.go new file mode 100644 index 0000000000..07a4fb4820 --- /dev/null +++ b/internal/master/segment_assigner.go @@ -0,0 +1,187 @@ +package master + +import ( + "context" + "log" + "sync" + "time" + + "github.com/zilliztech/milvus-distributed/internal/util/tsoutil" + + "github.com/zilliztech/milvus-distributed/internal/errors" + + ms "github.com/zilliztech/milvus-distributed/internal/msgstream" +) + +type Assignment struct { + rowNums int + 
expireTime Timestamp +} + +type Status struct { + total int + lastExpireTime Timestamp + assignments []*Assignment +} + +type SegmentAssigner struct { + mt *metaTable + segmentStatus map[UniqueID]*Status //segment id -> status + + globalTSOAllocator func() (Timestamp, error) + segmentExpireDuration int64 + + proxyTimeSyncChan chan *ms.TimeTickMsg + ctx context.Context + cancel context.CancelFunc + waitGroup sync.WaitGroup + mu sync.Mutex +} + +func (assigner *SegmentAssigner) OpenSegment(segmentID UniqueID, numRows int) error { + assigner.mu.Lock() + defer assigner.mu.Unlock() + if _, ok := assigner.segmentStatus[segmentID]; ok { + return errors.Errorf("can not reopen segment %d", segmentID) + } + + newStatus := &Status{ + total: numRows, + assignments: make([]*Assignment, 0), + } + assigner.segmentStatus[segmentID] = newStatus + return nil +} + +func (assigner *SegmentAssigner) CloseSegment(segmentID UniqueID) error { + assigner.mu.Lock() + defer assigner.mu.Unlock() + if _, ok := assigner.segmentStatus[segmentID]; !ok { + return errors.Errorf("can not find segment %d", segmentID) + } + + delete(assigner.segmentStatus, segmentID) + return nil +} + +func (assigner *SegmentAssigner) Assign(segmentID UniqueID, numRows int) (bool, error) { + assigner.mu.Lock() + defer assigner.mu.Unlock() + status, ok := assigner.segmentStatus[segmentID] + if !ok { + return false, errors.Errorf("segment %d is not opened", segmentID) + } + + allocated, err := assigner.totalOfAssignments(segmentID) + if err != nil { + return false, err + } + + segMeta, err := assigner.mt.GetSegmentByID(segmentID) + if err != nil { + return false, err + } + free := status.total - int(segMeta.NumRows) - allocated + if numRows > free { + return false, nil + } + + ts, err := assigner.globalTSOAllocator() + if err != nil { + return false, err + } + physicalTs, logicalTs := tsoutil.ParseTS(ts) + expirePhysicalTs := physicalTs.Add(time.Duration(assigner.segmentExpireDuration)) + expireTs := 
tsoutil.ComposeTS(expirePhysicalTs.UnixNano()/int64(time.Millisecond), int64(logicalTs)) + status.lastExpireTime = expireTs + status.assignments = append(status.assignments, &Assignment{ + numRows, + ts, + }) + + return true, nil +} + +func (assigner *SegmentAssigner) CheckAssignmentExpired(segmentID UniqueID, timestamp Timestamp) (bool, error) { + assigner.mu.Lock() + defer assigner.mu.Unlock() + status, ok := assigner.segmentStatus[segmentID] + if !ok { + return false, errors.Errorf("can not find segment %d", segmentID) + } + + if timestamp >= status.lastExpireTime { + return true, nil + } + + return false, nil +} + +func (assigner *SegmentAssigner) Start() { + assigner.waitGroup.Add(1) + go assigner.startProxyTimeSync() +} + +func (assigner *SegmentAssigner) Close() { + assigner.cancel() + assigner.waitGroup.Wait() +} + +func (assigner *SegmentAssigner) startProxyTimeSync() { + defer assigner.waitGroup.Done() + for { + select { + case <-assigner.ctx.Done(): + log.Println("proxy time sync stopped") + return + case msg := <-assigner.proxyTimeSyncChan: + if err := assigner.syncProxyTimeStamp(msg.TimeTickMsg.Timestamp); err != nil { + log.Println("proxy time sync error: " + err.Error()) + } + } + } +} + +func (assigner *SegmentAssigner) totalOfAssignments(segmentID UniqueID) (int, error) { + if _, ok := assigner.segmentStatus[segmentID]; !ok { + return -1, errors.Errorf("can not find segment %d", segmentID) + } + + status := assigner.segmentStatus[segmentID] + res := 0 + for _, v := range status.assignments { + res += v.rowNums + } + return res, nil +} + +func (assigner *SegmentAssigner) syncProxyTimeStamp(timeTick Timestamp) error { + assigner.mu.Lock() + defer assigner.mu.Unlock() + for _, status := range assigner.segmentStatus { + for i := 0; i < len(status.assignments); { + if timeTick >= status.assignments[i].expireTime { + status.assignments[i] = status.assignments[len(status.assignments)-1] + status.assignments = status.assignments[:len(status.assignments)-1] 
+ continue + } + i++ + } + } + + return nil +} + +func NewSegmentAssigner(ctx context.Context, metaTable *metaTable, + globalTSOAllocator func() (Timestamp, error), proxyTimeSyncChan chan *ms.TimeTickMsg) *SegmentAssigner { + assignCtx, cancel := context.WithCancel(ctx) + return &SegmentAssigner{ + mt: metaTable, + segmentStatus: make(map[UniqueID]*Status), + globalTSOAllocator: globalTSOAllocator, + segmentExpireDuration: Params.SegIDAssignExpiration, + proxyTimeSyncChan: proxyTimeSyncChan, + ctx: assignCtx, + cancel: cancel, + } +} diff --git a/internal/master/segment_manager.go b/internal/master/segment_manager.go index 3219c25bcf..74929fb74a 100644 --- a/internal/master/segment_manager.go +++ b/internal/master/segment_manager.go @@ -1,250 +1,191 @@ package master import ( + "context" + "log" "sync" - "time" + + "github.com/zilliztech/milvus-distributed/internal/util/typeutil" + + "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" "github.com/zilliztech/milvus-distributed/internal/errors" - "github.com/zilliztech/milvus-distributed/internal/msgstream" - "github.com/zilliztech/milvus-distributed/internal/proto/etcdpb" "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" + + ms "github.com/zilliztech/milvus-distributed/internal/msgstream" ) type collectionStatus struct { - openedSegments []UniqueID + segments []*segmentStatus } - -type assignment struct { - MemSize int64 // bytes - AssignTime time.Time +type segmentStatus struct { + segmentID UniqueID + total int + closable bool } type channelRange struct { channelStart int32 channelEnd int32 } - -type segmentStatus struct { - assignments []*assignment -} - type SegmentManager struct { metaTable *metaTable - statsStream msgstream.MsgStream channelRanges []*channelRange - segmentStatus map[UniqueID]*segmentStatus // segment id to segment status collStatus map[UniqueID]*collectionStatus // collection id to collection status defaultSizePerRecord int64 - minimumAssignSize int64 
segmentThreshold float64 segmentThresholdFactor float64 - segmentExpireDuration int64 numOfChannels int numOfQueryNodes int globalIDAllocator func() (UniqueID, error) globalTSOAllocator func() (Timestamp, error) mu sync.RWMutex + + assigner *SegmentAssigner + + writeNodeTimeSyncChan chan *ms.TimeTickMsg + flushScheduler persistenceScheduler + + ctx context.Context + cancel context.CancelFunc + waitGroup sync.WaitGroup } -func (segMgr *SegmentManager) HandleQueryNodeMsgPack(msgPack *msgstream.MsgPack) error { - segMgr.mu.Lock() - defer segMgr.mu.Unlock() - for _, msg := range msgPack.Msgs { - statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) - if !ok { - return errors.Errorf("Type of message is not QueryNodeStatsMsg") - } +func (manager *SegmentManager) AssignSegment(segIDReq []*internalpb.SegIDRequest) ([]*internalpb.SegIDAssignment, error) { + manager.mu.Lock() + defer manager.mu.Unlock() - for _, segStat := range statsMsg.GetSegStats() { - err := segMgr.handleSegmentStat(segStat) - if err != nil { - return err - } - } - } - return nil -} - -func (segMgr *SegmentManager) handleSegmentStat(segStats *internalpb.SegmentStats) error { - if !segStats.GetRecentlyModified() { - return nil - } - segID := segStats.GetSegmentID() - segMeta, err := segMgr.metaTable.GetSegmentByID(segID) - if err != nil { - return err - } - segMeta.NumRows = segStats.NumRows - segMeta.MemSize = segStats.MemorySize - - if segStats.MemorySize > int64(segMgr.segmentThresholdFactor*segMgr.segmentThreshold) { - if err := segMgr.metaTable.UpdateSegment(segMeta); err != nil { - return err - } - return segMgr.closeSegment(segMeta) - } - return segMgr.metaTable.UpdateSegment(segMeta) -} - -func (segMgr *SegmentManager) closeSegment(segMeta *etcdpb.SegmentMeta) error { - if segMeta.GetCloseTime() == 0 { - // close the segment and remove from collStatus - collStatus, ok := segMgr.collStatus[segMeta.GetCollectionID()] - if ok { - openedSegments := collStatus.openedSegments - for i, openedSegID := range 
openedSegments { - if openedSegID == segMeta.SegmentID { - openedSegments[i] = openedSegments[len(openedSegments)-1] - collStatus.openedSegments = openedSegments[:len(openedSegments)-1] - break - } - } - } - ts, err := segMgr.globalTSOAllocator() - if err != nil { - return err - } - segMeta.CloseTime = ts - } - - err := segMgr.metaTable.CloseSegment(segMeta.SegmentID, segMeta.GetCloseTime()) - if err != nil { - return err - } - return nil -} - -func (segMgr *SegmentManager) AssignSegmentID(segIDReq []*internalpb.SegIDRequest) ([]*internalpb.SegIDAssignment, error) { - segMgr.mu.Lock() - defer segMgr.mu.Unlock() res := make([]*internalpb.SegIDAssignment, 0) + for _, req := range segIDReq { collName := req.CollName partitionTag := req.PartitionTag count := req.Count channelID := req.ChannelID - collMeta, err := segMgr.metaTable.GetCollectionByName(collName) + + collMeta, err := manager.metaTable.GetCollectionByName(collName) if err != nil { return nil, err } collID := collMeta.GetID() - if !segMgr.metaTable.HasCollection(collID) { - return nil, errors.Errorf("can not find collection with id=%d", collID) - } - if !segMgr.metaTable.HasPartition(collID, partitionTag) { + if !manager.metaTable.HasPartition(collID, partitionTag) { return nil, errors.Errorf("partition tag %s can not find in coll %d", partitionTag, collID) } - collStatus, ok := segMgr.collStatus[collID] - if !ok { - collStatus = &collectionStatus{ - openedSegments: make([]UniqueID, 0), - } - segMgr.collStatus[collID] = collStatus - } - assignInfo, err := segMgr.assignSegment(collName, collID, partitionTag, count, channelID, collStatus) + assignInfo, err := manager.assignSegment(collName, collID, partitionTag, count, channelID) if err != nil { return nil, err } + res = append(res, assignInfo) } + return res, nil } -func (segMgr *SegmentManager) assignSegment(collName string, collID UniqueID, partitionTag string, count uint32, channelID int32, - collStatus *collectionStatus) (*internalpb.SegIDAssignment, 
error) { - segmentThreshold := int64(segMgr.segmentThreshold) - for _, segID := range collStatus.openedSegments { - segMeta, _ := segMgr.metaTable.GetSegmentByID(segID) - if segMeta.GetCloseTime() != 0 || channelID < segMeta.GetChannelStart() || - channelID > segMeta.GetChannelEnd() || segMeta.PartitionTag != partitionTag { +func (manager *SegmentManager) assignSegment( + collName string, + collID UniqueID, + partitionTag string, + count uint32, + channelID int32) (*internalpb.SegIDAssignment, error) { + + collStatus, ok := manager.collStatus[collID] + if !ok { + collStatus = &collectionStatus{ + segments: make([]*segmentStatus, 0), + } + manager.collStatus[collID] = collStatus + } + for _, segStatus := range collStatus.segments { + if segStatus.closable { continue } - // check whether segment has enough mem size - assignedMem := segMgr.checkAssignedSegExpire(segID) - memSize := segMeta.MemSize - neededMemSize := segMgr.calNeededSize(memSize, segMeta.NumRows, int64(count)) - if memSize+assignedMem+neededMemSize <= segmentThreshold { - remainingSize := segmentThreshold - memSize - assignedMem - allocMemSize := segMgr.calAllocMemSize(neededMemSize, remainingSize) - segMgr.addAssignment(segID, allocMemSize) - return &internalpb.SegIDAssignment{ - SegID: segID, - ChannelID: channelID, - Count: uint32(segMgr.calNumRows(memSize, segMeta.NumRows, allocMemSize)), - CollName: collName, - PartitionTag: partitionTag, - }, nil + match, err := manager.isMatch(segStatus.segmentID, partitionTag, channelID) + if err != nil { + return nil, err } - } - neededMemSize := segMgr.defaultSizePerRecord * int64(count) - if neededMemSize > segmentThreshold { - return nil, errors.Errorf("request with count %d need about %d mem size which is larger than segment threshold", - count, neededMemSize) + if !match { + continue + } + + result, err := manager.assigner.Assign(segStatus.segmentID, int(count)) + if err != nil { + return nil, err + } + if !result { + continue + } + + return 
&internalpb.SegIDAssignment{ + SegID: segStatus.segmentID, + ChannelID: channelID, + Count: count, + CollName: collName, + PartitionTag: partitionTag, + }, nil + } - segMeta, err := segMgr.openNewSegment(channelID, collID, partitionTag) + total, err := manager.estimateTotalRows(collName) + if err != nil { + return nil, err + } + if int(count) > total { + return nil, errors.Errorf("request count %d is larger than total rows %d", count, total) + } + + id, err := manager.openNewSegment(channelID, collID, partitionTag, total) if err != nil { return nil, err } - allocMemSize := segMgr.calAllocMemSize(neededMemSize, segmentThreshold) - segMgr.addAssignment(segMeta.SegmentID, allocMemSize) + result, err := manager.assigner.Assign(id, int(count)) + if err != nil { + return nil, err + } + if !result { + return nil, errors.Errorf("assign failed for segment %d", id) + } return &internalpb.SegIDAssignment{ - SegID: segMeta.SegmentID, + SegID: id, ChannelID: channelID, - Count: uint32(segMgr.calNumRows(0, 0, allocMemSize)), + Count: count, CollName: collName, PartitionTag: partitionTag, }, nil } -func (segMgr *SegmentManager) addAssignment(segID UniqueID, allocSize int64) { - segStatus := segMgr.segmentStatus[segID] - segStatus.assignments = append(segStatus.assignments, &assignment{ - MemSize: allocSize, - AssignTime: time.Now(), - }) +func (manager *SegmentManager) isMatch(segmentID UniqueID, partitionTag string, channelID int32) (bool, error) { + segMeta, err := manager.metaTable.GetSegmentByID(segmentID) + if err != nil { + return false, err + } + + if channelID < segMeta.GetChannelStart() || + channelID > segMeta.GetChannelEnd() || segMeta.PartitionTag != partitionTag { + return false, nil + } + return true, nil } -func (segMgr *SegmentManager) calNeededSize(memSize int64, numRows int64, count int64) int64 { - var avgSize int64 - if memSize == 0 || numRows == 0 || memSize/numRows == 0 { - avgSize = segMgr.defaultSizePerRecord - } else { - avgSize = memSize / numRows +func 
(manager *SegmentManager) estimateTotalRows(collName string) (int, error) { + collMeta, err := manager.metaTable.GetCollectionByName(collName) + if err != nil { + return -1, err } - return avgSize * count + sizePerRecord, err := typeutil.EstimateSizePerRecord(collMeta.Schema) + if err != nil { + return -1, err + } + return int(manager.segmentThreshold / float64(sizePerRecord)), nil } -func (segMgr *SegmentManager) calAllocMemSize(neededSize int64, remainSize int64) int64 { - if neededSize > remainSize { - return 0 - } - if remainSize < segMgr.minimumAssignSize { - return remainSize - } - if neededSize < segMgr.minimumAssignSize { - return segMgr.minimumAssignSize - } - return neededSize -} - -func (segMgr *SegmentManager) calNumRows(memSize int64, numRows int64, allocMemSize int64) int64 { - var avgSize int64 - if memSize == 0 || numRows == 0 || memSize/numRows == 0 { - avgSize = segMgr.defaultSizePerRecord - } else { - avgSize = memSize / numRows - } - return allocMemSize / avgSize -} - -func (segMgr *SegmentManager) openNewSegment(channelID int32, collID UniqueID, partitionTag string) (*etcdpb.SegmentMeta, error) { +func (manager *SegmentManager) openNewSegment(channelID int32, collID UniqueID, partitionTag string, numRows int) (UniqueID, error) { // find the channel range channelStart, channelEnd := int32(-1), int32(-1) - for _, r := range segMgr.channelRanges { + for _, r := range manager.channelRanges { if channelID >= r.channelStart && channelID <= r.channelEnd { channelStart = r.channelStart channelEnd = r.channelEnd @@ -252,18 +193,19 @@ func (segMgr *SegmentManager) openNewSegment(channelID int32, collID UniqueID, p } } if channelStart == -1 { - return nil, errors.Errorf("can't find the channel range which contains channel %d", channelID) + return -1, errors.Errorf("can't find the channel range which contains channel %d", channelID) } - newID, err := segMgr.globalIDAllocator() + newID, err := manager.globalIDAllocator() if err != nil { - return nil, err + 
return -1, err } - openTime, err := segMgr.globalTSOAllocator() + openTime, err := manager.globalTSOAllocator() if err != nil { - return nil, err + return -1, err } - newSegMeta := &etcdpb.SegmentMeta{ + + err = manager.metaTable.AddSegment(&etcdpb.SegmentMeta{ SegmentID: newID, CollectionID: collID, PartitionTag: partitionTag, @@ -272,51 +214,119 @@ func (segMgr *SegmentManager) openNewSegment(channelID int32, collID UniqueID, p OpenTime: openTime, NumRows: 0, MemSize: 0, - } - - err = segMgr.metaTable.AddSegment(newSegMeta) + }) if err != nil { - return nil, err + return -1, err } - segMgr.segmentStatus[newID] = &segmentStatus{ - assignments: make([]*assignment, 0), + + err = manager.assigner.OpenSegment(newID, numRows) + if err != nil { + return -1, err } - collStatus := segMgr.collStatus[collID] - collStatus.openedSegments = append(collStatus.openedSegments, newSegMeta.SegmentID) - return newSegMeta, nil + + segStatus := &segmentStatus{ + segmentID: newID, + total: numRows, + closable: false, + } + + collStatus := manager.collStatus[collID] + collStatus.segments = append(collStatus.segments, segStatus) + + return newID, nil } -// checkAssignedSegExpire check the expire time of assignments and return the total sum of assignments that are not expired. 
-func (segMgr *SegmentManager) checkAssignedSegExpire(segID UniqueID) int64 { - segStatus := segMgr.segmentStatus[segID] - assignments := segStatus.assignments - result := int64(0) - i := 0 - for i < len(assignments) { - assign := assignments[i] - if time.Since(assign.AssignTime) >= time.Duration(segMgr.segmentExpireDuration)*time.Millisecond { - assignments[i] = assignments[len(assignments)-1] - assignments = assignments[:len(assignments)-1] - continue +func (manager *SegmentManager) Start() { + manager.waitGroup.Add(1) + go manager.startWriteNodeTimeSync() +} + +func (manager *SegmentManager) Close() { + manager.cancel() + manager.waitGroup.Wait() +} + +func (manager *SegmentManager) startWriteNodeTimeSync() { + defer manager.waitGroup.Done() + for { + select { + case <-manager.ctx.Done(): + log.Println("write node time sync stopped") + return + case msg := <-manager.writeNodeTimeSyncChan: + if err := manager.syncWriteNodeTimestamp(msg.TimeTickMsg.Timestamp); err != nil { + log.Println("write node time sync error: " + err.Error()) + } } - result += assign.MemSize - i++ } - segStatus.assignments = assignments - return result } -func (segMgr *SegmentManager) createChannelRanges() error { - div, rem := segMgr.numOfChannels/segMgr.numOfQueryNodes, segMgr.numOfChannels%segMgr.numOfQueryNodes - for i, j := 0, 0; i < segMgr.numOfChannels; j++ { +func (manager *SegmentManager) syncWriteNodeTimestamp(timeTick Timestamp) error { + manager.mu.Lock() + defer manager.mu.Unlock() + for _, status := range manager.collStatus { + for i, segStatus := range status.segments { + if !segStatus.closable { + closable, err := manager.judgeSegmentClosable(segStatus) + if err != nil { + return err + } + segStatus.closable = closable + if !segStatus.closable { + continue + } + } + + isExpired, err := manager.assigner.CheckAssignmentExpired(segStatus.segmentID, timeTick) + if err != nil { + return err + } + if !isExpired { + continue + } + status.segments = append(status.segments[:i], 
status.segments[i+1:]...) + ts, err := manager.globalTSOAllocator() + if err != nil { + return err + } + if err = manager.metaTable.CloseSegment(segStatus.segmentID, ts); err != nil { + return err + } + if err = manager.assigner.CloseSegment(segStatus.segmentID); err != nil { + return err + } + if err = manager.flushScheduler.Enqueue(segStatus.segmentID); err != nil { + return err + } + } + } + + return nil +} + +func (manager *SegmentManager) judgeSegmentClosable(status *segmentStatus) (bool, error) { + segMeta, err := manager.metaTable.GetSegmentByID(status.segmentID) + if err != nil { + return false, err + } + + if segMeta.NumRows >= int64(manager.segmentThresholdFactor*float64(status.total)) { + return true, nil + } + return false, nil +} + +func (manager *SegmentManager) initChannelRanges() error { + div, rem := manager.numOfChannels/manager.numOfQueryNodes, manager.numOfChannels%manager.numOfQueryNodes + for i, j := 0, 0; i < manager.numOfChannels; j++ { if j < rem { - segMgr.channelRanges = append(segMgr.channelRanges, &channelRange{ + manager.channelRanges = append(manager.channelRanges, &channelRange{ channelStart: int32(i), channelEnd: int32(i + div), }) i += div + 1 } else { - segMgr.channelRanges = append(segMgr.channelRanges, &channelRange{ + manager.channelRanges = append(manager.channelRanges, &channelRange{ channelStart: int32(i), channelEnd: int32(i + div - 1), }) @@ -325,26 +335,38 @@ func (segMgr *SegmentManager) createChannelRanges() error { } return nil } - -func NewSegmentManager(meta *metaTable, +func NewSegmentManager(ctx context.Context, + meta *metaTable, globalIDAllocator func() (UniqueID, error), globalTSOAllocator func() (Timestamp, error), -) *SegmentManager { - segMgr := &SegmentManager{ + syncWriteNodeChan chan *ms.TimeTickMsg, + scheduler persistenceScheduler, + assigner *SegmentAssigner) (*SegmentManager, error) { + + assignerCtx, cancel := context.WithCancel(ctx) + segAssigner := &SegmentManager{ metaTable: meta, channelRanges: 
make([]*channelRange, 0), - segmentStatus: make(map[UniqueID]*segmentStatus), collStatus: make(map[UniqueID]*collectionStatus), segmentThreshold: Params.SegmentSize * 1024 * 1024, segmentThresholdFactor: Params.SegmentSizeFactor, - segmentExpireDuration: Params.SegIDAssignExpiration, - minimumAssignSize: Params.MinSegIDAssignCnt * Params.DefaultRecordSize, defaultSizePerRecord: Params.DefaultRecordSize, numOfChannels: Params.TopicNum, numOfQueryNodes: Params.QueryNodeNum, globalIDAllocator: globalIDAllocator, globalTSOAllocator: globalTSOAllocator, + + assigner: assigner, + writeNodeTimeSyncChan: syncWriteNodeChan, + flushScheduler: scheduler, + + ctx: assignerCtx, + cancel: cancel, } - segMgr.createChannelRanges() - return segMgr + + if err := segAssigner.initChannelRanges(); err != nil { + return nil, err + } + + return segAssigner, nil } diff --git a/internal/master/segment_manager_test.go b/internal/master/segment_manager_test.go index b357ddeb43..2e6e17c5bd 100644 --- a/internal/master/segment_manager_test.go +++ b/internal/master/segment_manager_test.go @@ -2,15 +2,12 @@ package master import ( "context" - "log" "sync/atomic" "testing" "time" "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" - "github.com/zilliztech/milvus-distributed/internal/errors" - "github.com/zilliztech/milvus-distributed/internal/kv" etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd" "github.com/zilliztech/milvus-distributed/internal/msgstream" "github.com/zilliztech/milvus-distributed/internal/proto/commonpb" @@ -25,61 +22,11 @@ import ( "google.golang.org/grpc" ) -var mt *metaTable -var segMgr *SegmentManager -var collName = "coll_segmgr_test" -var collID = int64(1001) -var partitionTag = "test" -var kvBase kv.TxnBase -var master *Master -var masterCancelFunc context.CancelFunc +func TestSegmentManager_AssignSegment(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.TODO()) + defer cancelFunc() -func setup() { Init() - etcdAddress := 
Params.EtcdAddress - - cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}}) - if err != nil { - panic(err) - } - rootPath := "/test/root" - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() - _, err = cli.Delete(ctx, rootPath, clientv3.WithPrefix()) - if err != nil { - panic(err) - } - kvBase = etcdkv.NewEtcdKV(cli, rootPath) - tmpMt, err := NewMetaTable(kvBase) - if err != nil { - panic(err) - } - mt = tmpMt - if mt.HasCollection(collID) { - err := mt.DeleteCollection(collID) - if err != nil { - panic(err) - } - } - err = mt.AddCollection(&pb.CollectionMeta{ - ID: collID, - Schema: &schemapb.CollectionSchema{ - Name: collName, - }, - CreateTime: 0, - SegmentIDs: []UniqueID{}, - PartitionTags: []string{}, - }) - if err != nil { - panic(err) - } - err = mt.AddPartition(collID, partitionTag) - if err != nil { - panic(err) - } - - var cnt int64 - Params.TopicNum = 5 Params.QueryNodeNum = 3 Params.SegmentSize = 536870912 / 1024 / 1024 @@ -87,151 +34,143 @@ func setup() { Params.DefaultRecordSize = 1024 Params.MinSegIDAssignCnt = 1048576 / 1024 Params.SegIDAssignExpiration = 2000 + etcdAddress := Params.EtcdAddress + cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}}) + assert.Nil(t, err) + rootPath := "/test/root" + _, err = cli.Delete(ctx, rootPath, clientv3.WithPrefix()) + assert.Nil(t, err) - segMgr = NewSegmentManager(mt, - func() (UniqueID, error) { - val := atomic.AddInt64(&cnt, 1) - return val, nil + kvBase := etcdkv.NewEtcdKV(cli, rootPath) + defer kvBase.Close() + mt, err := NewMetaTable(kvBase) + assert.Nil(t, err) + + collName := "segmgr_test_coll" + var collID int64 = 1001 + partitionTag := "test_part" + schema := &schemapb.CollectionSchema{ + Name: collName, + Fields: []*schemapb.FieldSchema{ + {FieldID: 1, Name: "f1", IsPrimaryKey: false, DataType: schemapb.DataType_INT32}, + {FieldID: 2, Name: "f2", IsPrimaryKey: false, DataType: schemapb.DataType_VECTOR_FLOAT, TypeParams: 
[]*commonpb.KeyValuePair{ + {Key: "dim", Value: "128"}, + }}, }, - func() (Timestamp, error) { - val := atomic.AddInt64(&cnt, 1) - phy := time.Now().UnixNano() / int64(time.Millisecond) - ts := tsoutil.ComposeTS(phy, val) - return ts, nil - }, - ) -} - -func teardown() { - err := mt.DeleteCollection(collID) - if err != nil { - log.Fatalf(err.Error()) } - kvBase.Close() -} + err = mt.AddCollection(&pb.CollectionMeta{ + ID: collID, + Schema: schema, + CreateTime: 0, + SegmentIDs: []UniqueID{}, + PartitionTags: []string{}, + }) + assert.Nil(t, err) + err = mt.AddPartition(collID, partitionTag) + assert.Nil(t, err) -func TestSegmentManager_AssignSegmentID(t *testing.T) { - setup() - defer teardown() - reqs := []*internalpb.SegIDRequest{ - {CollName: collName, PartitionTag: partitionTag, Count: 25000, ChannelID: 0}, - {CollName: collName, PartitionTag: partitionTag, Count: 10000, ChannelID: 1}, - {CollName: collName, PartitionTag: partitionTag, Count: 30000, ChannelID: 2}, - {CollName: collName, PartitionTag: partitionTag, Count: 25000, ChannelID: 3}, - {CollName: collName, PartitionTag: partitionTag, Count: 10000, ChannelID: 4}, + var cnt int64 + globalIDAllocator := func() (UniqueID, error) { + val := atomic.AddInt64(&cnt, 1) + return val, nil + } + globalTsoAllocator := func() (Timestamp, error) { + val := atomic.AddInt64(&cnt, 1) + phy := time.Now().UnixNano() / int64(time.Millisecond) + ts := tsoutil.ComposeTS(phy, val) + return ts, nil + } + syncWriteChan := make(chan *msgstream.TimeTickMsg) + syncProxyChan := make(chan *msgstream.TimeTickMsg) + + segAssigner := NewSegmentAssigner(ctx, mt, globalTsoAllocator, syncProxyChan) + mockScheduler := &MockFlushScheduler{} + segManager, err := NewSegmentManager(ctx, mt, globalIDAllocator, globalTsoAllocator, syncWriteChan, mockScheduler, segAssigner) + assert.Nil(t, err) + + segManager.Start() + defer segManager.Close() + sizePerRecord, err := typeutil.EstimateSizePerRecord(schema) + assert.Nil(t, err) + maxCount := 
uint32(Params.SegmentSize * 1024 * 1024 / float64(sizePerRecord)) + cases := []struct { + Count uint32 + ChannelID int32 + Err bool + SameIDWith int + NotSameIDWith int + ResultCount int32 + }{ + {1000, 1, false, -1, -1, 1000}, + {2000, 0, false, 0, -1, 2000}, + {maxCount - 2999, 1, false, -1, 0, int32(maxCount - 2999)}, + {maxCount - 3000, 1, false, 0, -1, int32(maxCount - 3000)}, + {2000000000, 1, true, -1, -1, -1}, + {1000, 3, false, -1, 0, 1000}, + {maxCount, 2, false, -1, -1, int32(maxCount)}, } - segAssigns, err := segMgr.AssignSegmentID(reqs) - assert.Nil(t, err) - - assert.Equal(t, uint32(25000), segAssigns[0].Count) - assert.Equal(t, uint32(10000), segAssigns[1].Count) - assert.Equal(t, uint32(30000), segAssigns[2].Count) - assert.Equal(t, uint32(25000), segAssigns[3].Count) - assert.Equal(t, uint32(10000), segAssigns[4].Count) - - assert.Equal(t, segAssigns[0].SegID, segAssigns[1].SegID) - assert.Equal(t, segAssigns[2].SegID, segAssigns[3].SegID) - - newReqs := []*internalpb.SegIDRequest{ - {CollName: collName, PartitionTag: partitionTag, Count: 500000, ChannelID: 0}, + var results = make([]*internalpb.SegIDAssignment, 0) + for _, c := range cases { + result, err := segManager.AssignSegment([]*internalpb.SegIDRequest{{Count: c.Count, ChannelID: c.ChannelID, CollName: collName, PartitionTag: partitionTag}}) + results = append(results, result...) 
+ if c.Err { + assert.NotNil(t, err) + continue + } + assert.Nil(t, err) + if c.SameIDWith != -1 { + assert.EqualValues(t, result[0].SegID, results[c.SameIDWith].SegID) + } + if c.NotSameIDWith != -1 { + assert.NotEqualValues(t, result[0].SegID, results[c.NotSameIDWith].SegID) + } + if c.ResultCount != -1 { + assert.EqualValues(t, result[0].Count, c.ResultCount) + } } - // test open a new segment - newAssign, err := segMgr.AssignSegmentID(newReqs) + + time.Sleep(time.Duration(Params.SegIDAssignExpiration) * time.Millisecond) + timestamp, err := globalTsoAllocator() assert.Nil(t, err) - assert.NotNil(t, newAssign) - assert.Equal(t, uint32(500000), newAssign[0].Count) - assert.NotEqual(t, segAssigns[0].SegID, newAssign[0].SegID) - - // test assignment expiration - time.Sleep(3 * time.Second) - - assignAfterExpiration, err := segMgr.AssignSegmentID(newReqs) - assert.Nil(t, err) - assert.NotNil(t, assignAfterExpiration) - assert.Equal(t, uint32(500000), assignAfterExpiration[0].Count) - assert.Equal(t, segAssigns[0].SegID, assignAfterExpiration[0].SegID) - - // test invalid params - newReqs[0].CollName = "wrong_collname" - _, err = segMgr.AssignSegmentID(newReqs) - assert.Error(t, errors.Errorf("can not find collection with id=%d", collID), err) - - newReqs[0].Count = 1000000 - _, err = segMgr.AssignSegmentID(newReqs) - assert.Error(t, errors.Errorf("request with count %d need about %d mem size which is larger than segment threshold", - 1000000, 1024*1000000), err) -} - -func TestSegmentManager_SegmentStats(t *testing.T) { - setup() - defer teardown() - ts, err := segMgr.globalTSOAllocator() - assert.Nil(t, err) - err = mt.AddSegment(&pb.SegmentMeta{ - SegmentID: 100, + err = mt.UpdateSegment(&pb.SegmentMeta{ + SegmentID: results[0].SegID, CollectionID: collID, PartitionTag: partitionTag, ChannelStart: 0, ChannelEnd: 1, - OpenTime: ts, + CloseTime: timestamp, + NumRows: 400000, + MemSize: 500000, }) assert.Nil(t, err) - stats := internalpb.QueryNodeStats{ - MsgType: 
internalpb.MsgType_kQueryNodeStats, - PeerID: 1, - SegStats: []*internalpb.SegmentStats{ - {SegmentID: 100, MemorySize: 2500000, NumRows: 25000, RecentlyModified: true}, + tsMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: timestamp, EndTimestamp: timestamp, HashValues: []uint32{}, + }, + TimeTickMsg: internalpb.TimeTickMsg{ + MsgType: internalpb.MsgType_kTimeTick, + PeerID: 1, + Timestamp: timestamp, }, } - baseMsg := msgstream.BaseMsg{ - BeginTimestamp: 0, - EndTimestamp: 0, - HashValues: []uint32{1}, - } - msg := msgstream.QueryNodeStatsMsg{ - QueryNodeStats: stats, - BaseMsg: baseMsg, - } - - var tsMsg msgstream.TsMsg = &msg - msgPack := msgstream.MsgPack{ - Msgs: make([]msgstream.TsMsg, 0), - } - msgPack.Msgs = append(msgPack.Msgs, tsMsg) - err = segMgr.HandleQueryNodeMsgPack(&msgPack) + syncWriteChan <- tsMsg + time.Sleep(300 * time.Millisecond) + segMeta, err := mt.GetSegmentByID(results[0].SegID) assert.Nil(t, err) - - segMeta, _ := mt.GetSegmentByID(100) - assert.Equal(t, int64(100), segMeta.SegmentID) - assert.Equal(t, int64(2500000), segMeta.MemSize) - assert.Equal(t, int64(25000), segMeta.NumRows) - - // close segment - stats.SegStats[0].NumRows = 600000 - stats.SegStats[0].MemorySize = int64(0.8 * segMgr.segmentThreshold) - err = segMgr.HandleQueryNodeMsgPack(&msgPack) - assert.Nil(t, err) - segMeta, _ = mt.GetSegmentByID(100) - assert.Equal(t, int64(100), segMeta.SegmentID) - assert.NotEqual(t, uint64(0), segMeta.CloseTime) + assert.NotEqualValues(t, 0, segMeta.CloseTime) } - -func startupMaster() { +func TestSegmentManager_RPC(t *testing.T) { Init() refreshMasterAddress() etcdAddress := Params.EtcdAddress rootPath := "/test/root" ctx, cancel := context.WithCancel(context.TODO()) - masterCancelFunc = cancel + defer cancel() cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}}) - if err != nil { - panic(err) - } + assert.Nil(t, err) _, err = cli.Delete(ctx, rootPath, clientv3.WithPrefix()) - if err != 
nil { - panic(err) - } + assert.Nil(t, err) Params = ParamTable{ Address: Params.Address, Port: Params.Port, @@ -268,27 +207,13 @@ func startupMaster() { DefaultPartitionTag: "_default", } - master, err = CreateServer(ctx) - if err != nil { - panic(err) - } + collName := "test_coll" + partitionTag := "test_part" + master, err := CreateServer(ctx) + assert.Nil(t, err) + defer master.Close() err = master.Run(int64(Params.Port)) - - if err != nil { - panic(err) - } -} - -func shutdownMaster() { - masterCancelFunc() - master.Close() -} - -func TestSegmentManager_RPC(t *testing.T) { - startupMaster() - defer shutdownMaster() - ctx, cancel := context.WithCancel(context.TODO()) - defer cancel() + assert.Nil(t, err) dialContext, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock()) assert.Nil(t, err) defer dialContext.Close() @@ -297,7 +222,10 @@ func TestSegmentManager_RPC(t *testing.T) { Name: collName, Description: "test coll", AutoID: false, - Fields: []*schemapb.FieldSchema{}, + Fields: []*schemapb.FieldSchema{ + {FieldID: 1, Name: "f1", IsPrimaryKey: false, DataType: schemapb.DataType_INT32}, + {FieldID: 2, Name: "f2", IsPrimaryKey: false, DataType: schemapb.DataType_VECTOR_FLOAT, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}}, + }, } schemaBytes, err := proto.Marshal(schema) assert.Nil(t, err) @@ -329,13 +257,13 @@ func TestSegmentManager_RPC(t *testing.T) { }, }) assert.Nil(t, err) - assert.Equal(t, commonpb.ErrorCode_SUCCESS, resp.Status.ErrorCode) + assert.EqualValues(t, commonpb.ErrorCode_SUCCESS, resp.Status.ErrorCode) assignments := resp.GetPerChannelAssignment() - assert.Equal(t, 1, len(assignments)) - assert.Equal(t, collName, assignments[0].CollName) - assert.Equal(t, partitionTag, assignments[0].PartitionTag) - assert.Equal(t, int32(0), assignments[0].ChannelID) - assert.Equal(t, uint32(10000), assignments[0].Count) + assert.EqualValues(t, 1, len(assignments)) + assert.EqualValues(t, collName, 
assignments[0].CollName) + assert.EqualValues(t, partitionTag, assignments[0].PartitionTag) + assert.EqualValues(t, int32(0), assignments[0].ChannelID) + assert.EqualValues(t, uint32(10000), assignments[0].Count) // test stats segID := assignments[0].SegID @@ -369,6 +297,6 @@ func TestSegmentManager_RPC(t *testing.T) { time.Sleep(500 * time.Millisecond) segMeta, err := master.metaTable.GetSegmentByID(segID) assert.Nil(t, err) - assert.NotEqual(t, uint64(0), segMeta.GetCloseTime()) - assert.Equal(t, int64(600000000), segMeta.GetMemSize()) + assert.EqualValues(t, 1000000, segMeta.GetNumRows()) + assert.EqualValues(t, int64(600000000), segMeta.GetMemSize()) } diff --git a/internal/master/stats_processor.go b/internal/master/stats_processor.go new file mode 100644 index 0000000000..536123ad3f --- /dev/null +++ b/internal/master/stats_processor.go @@ -0,0 +1,58 @@ +package master + +import ( + "github.com/zilliztech/milvus-distributed/internal/errors" + "github.com/zilliztech/milvus-distributed/internal/msgstream" + "github.com/zilliztech/milvus-distributed/internal/proto/internalpb" +) + +type StatsProcessor struct { + metaTable *metaTable + segmentThreshold float64 + segmentThresholdFactor float64 + globalTSOAllocator func() (Timestamp, error) +} + +func (processor *StatsProcessor) ProcessQueryNodeStats(msgPack *msgstream.MsgPack) error { + for _, msg := range msgPack.Msgs { + statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg) + if !ok { + return errors.Errorf("Type of message is not QueryNodeSegStatsMsg") + } + + for _, segStat := range statsMsg.GetSegStats() { + err := processor.processSegmentStat(segStat) + if err != nil { + return err + } + } + } + + return nil +} + +func (processor *StatsProcessor) processSegmentStat(segStats *internalpb.SegmentStats) error { + if !segStats.GetRecentlyModified() { + return nil + } + + segID := segStats.GetSegmentID() + segMeta, err := processor.metaTable.GetSegmentByID(segID) + if err != nil { + return err + } + + segMeta.NumRows 
= segStats.NumRows + segMeta.MemSize = segStats.MemorySize + + return processor.metaTable.UpdateSegment(segMeta) +} + +func NewStatsProcessor(mt *metaTable, globalTSOAllocator func() (Timestamp, error)) *StatsProcessor { + return &StatsProcessor{ + metaTable: mt, + segmentThreshold: Params.SegmentSize * 1024 * 1024, + segmentThresholdFactor: Params.SegmentSizeFactor, + globalTSOAllocator: globalTSOAllocator, + } +} diff --git a/internal/master/time_sync_producer.go b/internal/master/time_sync_producer.go index 94615d57cc..ec7dac2e62 100644 --- a/internal/master/time_sync_producer.go +++ b/internal/master/time_sync_producer.go @@ -21,6 +21,9 @@ type timeSyncMsgProducer struct { ctx context.Context cancel context.CancelFunc + + proxyWatchers []chan *ms.TimeTickMsg + writeNodeWatchers []chan *ms.TimeTickMsg } func NewTimeSyncMsgProducer(ctx context.Context) (*timeSyncMsgProducer, error) { @@ -47,7 +50,15 @@ func (syncMsgProducer *timeSyncMsgProducer) SetK2sSyncStream(k2sSync ms.MsgStrea syncMsgProducer.k2sSyncStream = k2sSync } -func (syncMsgProducer *timeSyncMsgProducer) broadcastMsg(barrier TimeTickBarrier, streams []ms.MsgStream) error { +func (syncMsgProducer *timeSyncMsgProducer) WatchProxyTtBarrier(watcher chan *ms.TimeTickMsg) { + syncMsgProducer.proxyWatchers = append(syncMsgProducer.proxyWatchers, watcher) +} + +func (syncMsgProducer *timeSyncMsgProducer) WatchWriteNodeTtBarrier(watcher chan *ms.TimeTickMsg) { + syncMsgProducer.writeNodeWatchers = append(syncMsgProducer.writeNodeWatchers, watcher) +} + +func (syncMsgProducer *timeSyncMsgProducer) broadcastMsg(barrier TimeTickBarrier, streams []ms.MsgStream, channels []chan *ms.TimeTickMsg) error { for { select { case <-syncMsgProducer.ctx.Done(): @@ -79,6 +90,10 @@ func (syncMsgProducer *timeSyncMsgProducer) broadcastMsg(barrier TimeTickBarrier for _, stream := range streams { err = stream.Broadcast(&msgPack) } + + for _, channel := range channels { + channel <- timeTickMsg + } if err != nil { return err } @@ 
-97,8 +112,8 @@ func (syncMsgProducer *timeSyncMsgProducer) Start() error { return err } - go syncMsgProducer.broadcastMsg(syncMsgProducer.proxyTtBarrier, []ms.MsgStream{syncMsgProducer.dmSyncStream, syncMsgProducer.ddSyncStream}) - go syncMsgProducer.broadcastMsg(syncMsgProducer.writeNodeTtBarrier, []ms.MsgStream{syncMsgProducer.k2sSyncStream}) + go syncMsgProducer.broadcastMsg(syncMsgProducer.proxyTtBarrier, []ms.MsgStream{syncMsgProducer.dmSyncStream, syncMsgProducer.ddSyncStream}, syncMsgProducer.proxyWatchers) + go syncMsgProducer.broadcastMsg(syncMsgProducer.writeNodeTtBarrier, []ms.MsgStream{syncMsgProducer.k2sSyncStream}, syncMsgProducer.writeNodeWatchers) return nil } diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go new file mode 100644 index 0000000000..5046e737a9 --- /dev/null +++ b/internal/util/typeutil/schema.go @@ -0,0 +1,36 @@ +package typeutil + +import ( + "strconv" + + "github.com/zilliztech/milvus-distributed/internal/proto/schemapb" +) + +func EstimateSizePerRecord(schema *schemapb.CollectionSchema) (int, error) { + res := 0 + for _, fs := range schema.Fields { + switch fs.DataType { + case schemapb.DataType_BOOL, schemapb.DataType_INT8: + res++ + case schemapb.DataType_INT16: + res += 2 + case schemapb.DataType_INT32, schemapb.DataType_FLOAT: + res += 4 + case schemapb.DataType_INT64, schemapb.DataType_DOUBLE: + res += 8 + case schemapb.DataType_STRING: + res += 125 // todo find a better way to estimate string type + case schemapb.DataType_VECTOR_BINARY, schemapb.DataType_VECTOR_FLOAT: + for _, kv := range fs.TypeParams { + if kv.Key == "dim" { + v, err := strconv.Atoi(kv.Value) + if err != nil { + return -1, err + } + res += v + } + } + } + } + return res, nil +}