Refactor segment manager

Signed-off-by: sunby <bingyi.sun@zilliz.com>
This commit is contained in:
sunby 2020-12-19 12:55:24 +08:00 committed by yefu.chen
parent e65cfe1e3d
commit c71cd40f68
13 changed files with 794 additions and 729 deletions

View File

@ -61,7 +61,7 @@ ruleguard:
verifiers: getdeps cppcheck fmt lint ruleguard
# Builds various components locally.
build-go: build-cpp
build-go:
@echo "Building each component's binary to './bin'"
@echo "Building master ..."
@mkdir -p $(INSTALL_PATH) && go env -w CGO_ENABLED="0" && GO111MODULE=on $(GO) build -o $(INSTALL_PATH)/master $(PWD)/cmd/master/main.go 1>/dev/null

View File

@ -20,7 +20,6 @@
#include <memory>
#include <boost/align/aligned_allocator.hpp>
#include <boost/algorithm/string.hpp>
#include <algorithm>
namespace milvus::query {
@ -40,8 +39,10 @@ const std::map<std::string, RangeExpr::OpType> RangeExpr::mapping_ = {
class Parser {
public:
friend std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str);
static std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str) {
return Parser(schema).CreatePlanImpl(dsl_str);
}
private:
std::unique_ptr<Plan>
@ -50,55 +51,29 @@ class Parser {
explicit Parser(const Schema& schema) : schema(schema) {
}
// vector node parser, should be called exactly once per pass.
std::unique_ptr<VectorPlanNode>
ParseVecNode(const Json& out_body);
// Dispatcher of all parse function
// NOTE: when nullptr, it is a pure vector node
ExprPtr
ParseAnyNode(const Json& body);
ExprPtr
ParseMustNode(const Json& body);
ExprPtr
ParseShouldNode(const Json& body);
ExprPtr
ParseShouldNotNode(const Json& body);
// parse the value of "should"/"must"/"should_not" entry
std::vector<ExprPtr>
ParseItemList(const Json& body);
// parse the value of "range" entry
ExprPtr
ParseRangeNode(const Json& out_body);
// parse the value of "term" entry
ExprPtr
ParseTermNode(const Json& out_body);
private:
// template implementation of leaf parser
// used by corresponding parser
template <typename T>
ExprPtr
std::unique_ptr<Expr>
ParseRangeNodeImpl(const std::string& field_name, const Json& body);
template <typename T>
ExprPtr
std::unique_ptr<Expr>
ParseTermNodeImpl(const std::string& field_name, const Json& body);
std::unique_ptr<Expr>
ParseRangeNode(const Json& out_body);
std::unique_ptr<Expr>
ParseTermNode(const Json& out_body);
private:
const Schema& schema;
std::map<std::string, FieldId> tag2field_; // PlaceholderName -> FieldId
std::optional<std::unique_ptr<VectorPlanNode>> vector_node_opt_;
};
ExprPtr
std::unique_ptr<Expr>
Parser::ParseRangeNode(const Json& out_body) {
Assert(out_body.is_object());
Assert(out_body.size() == 1);
@ -109,8 +84,9 @@ Parser::ParseRangeNode(const Json& out_body) {
Assert(!field_is_vector(data_type));
switch (data_type) {
case DataType::BOOL:
case DataType::BOOL: {
return ParseRangeNodeImpl<bool>(field_name, body);
}
case DataType::INT8:
return ParseRangeNodeImpl<int8_t>(field_name, body);
case DataType::INT16:
@ -130,22 +106,51 @@ Parser::ParseRangeNode(const Json& out_body) {
std::unique_ptr<Plan>
Parser::CreatePlanImpl(const std::string& dsl_str) {
auto dsl = Json::parse(dsl_str);
auto bool_dsl = dsl.at("bool");
auto predicate = ParseAnyNode(bool_dsl);
Assert(vector_node_opt_.has_value());
auto vec_node = std::move(vector_node_opt_).value();
if (predicate != nullptr) {
vec_node->predicate_ = std::move(predicate);
}
auto plan = std::make_unique<Plan>(schema);
auto dsl = nlohmann::json::parse(dsl_str);
nlohmann::json vec_pack;
std::optional<std::unique_ptr<Expr>> predicate;
// top level
auto& bool_dsl = dsl.at("bool");
if (bool_dsl.contains("must")) {
auto& packs = bool_dsl.at("must");
Assert(packs.is_array());
for (auto& pack : packs) {
if (pack.contains("vector")) {
auto& out_body = pack.at("vector");
plan->plan_node_ = ParseVecNode(out_body);
} else if (pack.contains("term")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("term");
predicate = ParseTermNode(out_body);
} else if (pack.contains("range")) {
AssertInfo(!predicate, "unsupported complex DSL");
auto& out_body = pack.at("range");
predicate = ParseRangeNode(out_body);
} else {
PanicInfo("unsupported node");
}
}
AssertInfo(plan->plan_node_, "vector node not found");
} else if (bool_dsl.contains("vector")) {
auto& out_body = bool_dsl.at("vector");
plan->plan_node_ = ParseVecNode(out_body);
Assert(plan->plan_node_);
} else {
PanicInfo("Unsupported DSL");
}
plan->plan_node_->predicate_ = std::move(predicate);
plan->tag2field_ = std::move(tag2field_);
plan->plan_node_ = std::move(vec_node);
// TODO: target_entry parser
// if schema autoid is true,
// prepend target_entries_ with row_id
// else
// with primary_key
//
return plan;
}
ExprPtr
std::unique_ptr<Expr>
Parser::ParseTermNode(const Json& out_body) {
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
@ -216,7 +221,7 @@ Parser::ParseVecNode(const Json& out_body) {
}
template <typename T>
ExprPtr
std::unique_ptr<Expr>
Parser::ParseTermNodeImpl(const std::string& field_name, const Json& body) {
auto expr = std::make_unique<TermExprImpl<T>>();
auto data_type = schema[field_name].get_data_type();
@ -244,7 +249,7 @@ Parser::ParseTermNodeImpl(const std::string& field_name, const Json& body) {
}
template <typename T>
ExprPtr
std::unique_ptr<Expr>
Parser::ParseRangeNodeImpl(const std::string& field_name, const Json& body) {
auto expr = std::make_unique<RangeExprImpl<T>>();
auto data_type = schema[field_name].get_data_type();
@ -273,6 +278,12 @@ Parser::ParseRangeNodeImpl(const std::string& field_name, const Json& body) {
return expr;
}
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str) {
auto plan = Parser::CreatePlan(schema, dsl_str);
return plan;
}
std::unique_ptr<PlaceholderGroup>
ParsePlaceholderGroup(const Plan* plan, const std::string& blob) {
namespace ser = milvus::proto::service;
@ -302,150 +313,6 @@ ParsePlaceholderGroup(const Plan* plan, const std::string& blob) {
return result;
}
std::unique_ptr<Plan>
CreatePlan(const Schema& schema, const std::string& dsl_str) {
auto plan = Parser(schema).CreatePlanImpl(dsl_str);
return plan;
}
std::vector<ExprPtr>
Parser::ParseItemList(const Json& body) {
std::vector<ExprPtr> results;
if (body.is_object()) {
// only one item;
auto new_entry = ParseAnyNode(body);
results.emplace_back(std::move(new_entry));
} else {
// item array
Assert(body.is_array());
for (auto& item : body) {
auto new_entry = ParseAnyNode(item);
results.emplace_back(std::move(new_entry));
}
}
auto old_size = results.size();
auto new_end = std::remove_if(results.begin(), results.end(), [](const ExprPtr& x) { return x == nullptr; });
results.resize(new_end - results.begin());
return results;
}
// Dispatch a single DSL node by its (single) key.
// "must"/"should"/"should_not" recurse into boolean combinators;
// "range"/"term" produce leaf predicate expressions.
// A "vector" entry is stashed into vector_node_opt_ (at most once) and
// nullptr is returned so callers can filter it out of predicate lists.
ExprPtr
Parser::ParseAnyNode(const Json& out_body) {
Assert(out_body.is_object());
// Each node object must carry exactly one key (its node type).
Assert(out_body.size() == 1);
auto out_iter = out_body.begin();
auto key = out_iter.key();
auto body = out_iter.value();
if (key == "must") {
return ParseMustNode(body);
} else if (key == "should") {
return ParseShouldNode(body);
} else if (key == "should_not") {
return ParseShouldNotNode(body);
} else if (key == "range") {
return ParseRangeNode(body);
} else if (key == "term") {
return ParseTermNode(body);
} else if (key == "vector") {
auto vec_node = ParseVecNode(body);
// Only one vector node is allowed per plan.
Assert(!vector_node_opt_.has_value());
vector_node_opt_ = std::move(vec_node);
return nullptr;
} else {
PanicInfo("unsupported key: " + key);
}
}
// Fold a list of sub-expressions into a single expression tree using the
// given binary merger (e.g. a LogicalAnd or LogicalOr combiner).
// Returns nullptr for an empty list and the lone item for a singleton.
template <typename Merger>
static ExprPtr
ConstructTree(Merger merger, std::vector<ExprPtr> item_list) {
if (item_list.size() == 0) {
return nullptr;
}
if (item_list.size() == 1) {
return std::move(item_list[0]);
}
// Note: use deque to construct a binary tree
//      Op
//     /  \
//   Op    Op
//   | \   | \
//   A  B  C  D
std::deque<ExprPtr> binary_queue;
// Seed the queue; items must be non-null (vector nodes were filtered upstream).
for (auto& item : item_list) {
Assert(item != nullptr);
binary_queue.push_back(std::move(item));
}
// Repeatedly merge the two front items and push the result to the back,
// yielding a roughly balanced tree; the shape depends on this exact order.
while (binary_queue.size() > 1) {
auto left = std::move(binary_queue.front());
binary_queue.pop_front();
auto right = std::move(binary_queue.front());
binary_queue.pop_front();
binary_queue.push_back(merger(std::move(left), std::move(right)));
}
Assert(binary_queue.size() == 1);
return std::move(binary_queue.front());
}
// Parse a "must" clause: every sub-condition must hold, so the items are
// combined with LogicalAnd. An empty item list (e.g. a lone vector node)
// yields nullptr, meaning "no predicate".
ExprPtr
Parser::ParseMustNode(const Json& body) {
    auto combine_and = [](ExprPtr lhs, ExprPtr rhs) {
        auto node = std::make_unique<BoolBinaryExpr>();
        node->op_type_ = BoolBinaryExpr::OpType::LogicalAnd;
        node->left_ = std::move(lhs);
        node->right_ = std::move(rhs);
        return node;
    };
    return ConstructTree(combine_and, ParseItemList(body));
}
// Parse a "should" clause: at least one sub-condition must hold, so the
// items are combined with LogicalOr. An empty "should" list is rejected.
ExprPtr
Parser::ParseShouldNode(const Json& body) {
    auto sub_exprs = ParseItemList(body);
    Assert(sub_exprs.size() >= 1);
    auto combine_or = [](ExprPtr lhs, ExprPtr rhs) {
        auto node = std::make_unique<BoolBinaryExpr>();
        node->op_type_ = BoolBinaryExpr::OpType::LogicalOr;
        node->left_ = std::move(lhs);
        node->right_ = std::move(rhs);
        return node;
    };
    return ConstructTree(combine_or, std::move(sub_exprs));
}
// Parse a "should_not" clause: the items are combined with LogicalAnd and
// the conjunction is negated, i.e. the result matches when NOT all items
// match simultaneously.
ExprPtr
Parser::ParseShouldNotNode(const Json& body) {
auto item_list = ParseItemList(body);
Assert(item_list.size() >= 1);
auto merger = [](ExprPtr left, ExprPtr right) {
using OpType = BoolBinaryExpr::OpType;
auto res = std::make_unique<BoolBinaryExpr>();
res->op_type_ = OpType::LogicalAnd;
res->left_ = std::move(left);
res->right_ = std::move(right);
return res;
};
// Build the AND-subtree, then wrap it in a unary LogicalNot.
auto subtree = ConstructTree(merger, std::move(item_list));
using OpType = BoolUnaryExpr::OpType;
auto res = std::make_unique<BoolUnaryExpr>();
res->op_type_ = OpType::LogicalNot;
res->child_ = std::move(subtree);
return res;
}
int64_t
GetTopK(const Plan* plan) {
return plan->plan_node_->query_info_.topK_;

View File

@ -67,7 +67,6 @@ ExecExprVisitor::visit(BoolUnaryExpr& expr) {
switch (expr.op_type_) {
case OpType::LogicalNot: {
chunk.flip();
break;
}
default: {
PanicInfo("Invalid OpType");

View File

@ -410,104 +410,3 @@ TEST(Expr, TestTerm) {
}
}
}
// End-to-end check of simple boolean DSL parsing + execution: builds
// "must"/"should"/"should_not" clauses out of bit-pattern "term" filters,
// runs them through CreatePlan and ExecExprVisitor over generated segment
// data, and compares every row against a closed-form reference predicate.
TEST(Expr, TestSimpleDsl) {
using namespace milvus::query;
using namespace milvus::segcore;
auto vec_dsl = Json::parse(R"(
{
"vector": {
"fakevec": {
"metric_type": "L2",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 10
}
}
}
)");
int N = 32;
// Build a "term" clause over "age" selecting the values in [0, 2N) whose
// `base`-th bit equals `bit`.
auto get_item = [&](int base, int bit = 1) {
std::vector<int> terms;
// note: random gen range is [0, 2N)
for (int i = 0; i < N * 2; ++i) {
if (((i >> base) & 0x1) == bit) {
terms.push_back(i);
}
}
Json s;
s["term"]["age"]["values"] = terms;
return s;
};
// std::cout << get_item(0).dump(-2);
// std::cout << vec_dsl.dump(-2);
// Each testcase pairs a DSL clause with the reference predicate its
// predicate tree must reproduce on every row value.
std::vector<std::tuple<Json, std::function<bool(int)>>> testcases;
{
// must: all four bit filters AND-ed -> low nibble exactly 0b1011.
Json dsl;
dsl["must"] = Json::array({vec_dsl, get_item(0), get_item(1), get_item(2, 0), get_item(3)});
testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) == 0b1011; });
}
{
// nested must inside must: same predicate as above.
Json dsl;
Json sub_dsl;
sub_dsl["must"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)});
dsl["must"] = Json::array({sub_dsl, vec_dsl});
testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) == 0b1011; });
}
{
// should: OR of the four filters -> true unless the nibble is 0b0100.
Json dsl;
Json sub_dsl;
sub_dsl["should"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)});
dsl["must"] = Json::array({sub_dsl, vec_dsl});
testcases.emplace_back(dsl, [](int x) { return !!((x & 0b1111) ^ 0b0100); });
}
{
// should_not: NOT(all four filters) -> nibble differs from 0b1011.
Json dsl;
Json sub_dsl;
sub_dsl["should_not"] = Json::array({get_item(0), get_item(1), get_item(2, 0), get_item(3)});
dsl["must"] = Json::array({sub_dsl, vec_dsl});
testcases.emplace_back(dsl, [](int x) { return (x & 0b1111) != 0b1011; });
}
auto schema = std::make_shared<Schema>();
schema->AddField("fakevec", DataType::VECTOR_FLOAT, 16, MetricType::METRIC_L2);
schema->AddField("age", DataType::INT32);
auto seg = CreateSegment(schema);
// Insert num_iters batches of N rows, keeping a host-side copy of the
// "age" column for reference checking.
std::vector<int> age_col;
int num_iters = 100;
for (int iter = 0; iter < num_iters; ++iter) {
auto raw_data = DataGen(schema, N, iter);
auto new_age_col = raw_data.get_col<int>(1);
age_col.insert(age_col.end(), new_age_col.begin(), new_age_col.end());
seg->PreInsert(N);
seg->Insert(iter * N, N, raw_data.row_ids_.data(), raw_data.timestamps_.data(), raw_data.raw_);
}
auto seg_promote = dynamic_cast<SegmentSmallIndex*>(seg.get());
ExecExprVisitor visitor(*seg_promote);
for (auto [clause, ref_func] : testcases) {
Json dsl;
dsl["bool"] = clause;
// std::cout << dsl.dump(2);
auto plan = CreatePlan(*schema, dsl.dump());
// Evaluate the predicate tree and compare each row's bit against the
// reference predicate applied to the stored age value.
auto final = visitor.call_child(*plan->plan_node_->predicate_.value());
EXPECT_EQ(final.size(), upper_div(N * num_iters, DefaultElementPerChunk));
for (int i = 0; i < N * num_iters; ++i) {
auto vec_id = i / DefaultElementPerChunk;
auto offset = i % DefaultElementPerChunk;
bool ans = final[vec_id][offset];
auto val = age_col[i];
auto ref = ref_func(val);
ASSERT_EQ(ans, ref) << clause << "@" << i << "!!" << val;
}
}
}

View File

@ -437,7 +437,7 @@ func (s *Master) AllocID(ctx context.Context, request *internalpb.IDRequest) (*i
}
func (s *Master) AssignSegmentID(ctx context.Context, request *internalpb.AssignSegIDRequest) (*internalpb.AssignSegIDResponse, error) {
segInfos, err := s.segmentMgr.AssignSegmentID(request.GetPerChannelReq())
segInfos, err := s.segmentManager.AssignSegment(request.GetPerChannelReq())
if err != nil {
return &internalpb.AssignSegIDResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR},

View File

@ -54,7 +54,9 @@ type Master struct {
startCallbacks []func()
closeCallbacks []func()
segmentMgr *SegmentManager
segmentManager *SegmentManager
segmentAssigner *SegmentAssigner
statProcessor *StatsProcessor
segmentStatusMsg ms.MsgStream
//id allocator
@ -137,6 +139,11 @@ func CreateServer(ctx context.Context) (*Master, error) {
pulsarK2SStream.CreatePulsarProducers(Params.K2SChannelNames)
tsMsgProducer.SetK2sSyncStream(pulsarK2SStream)
proxyTtBarrierWatcher := make(chan *ms.TimeTickMsg, 1024)
writeNodeTtBarrierWatcher := make(chan *ms.TimeTickMsg, 1024)
tsMsgProducer.WatchProxyTtBarrier(proxyTtBarrierWatcher)
tsMsgProducer.WatchWriteNodeTtBarrier(writeNodeTtBarrierWatcher)
// stats msg stream
statsMs := ms.NewPulsarMsgStream(ctx, 1024)
statsMs.SetPulsarClient(pulsarAddr)
@ -169,9 +176,23 @@ func CreateServer(ctx context.Context) (*Master, error) {
m.scheduler.SetDDMsgStream(pulsarDDStream)
m.scheduler.SetIDAllocator(func() (UniqueID, error) { return m.idAllocator.AllocOne() })
m.segmentMgr = NewSegmentManager(metakv,
m.segmentAssigner = NewSegmentAssigner(ctx, metakv,
func() (Timestamp, error) { return m.tsoAllocator.AllocOne() },
proxyTtBarrierWatcher,
)
m.segmentManager, err = NewSegmentManager(ctx, metakv,
func() (UniqueID, error) { return m.idAllocator.AllocOne() },
func() (Timestamp, error) { return m.tsoAllocator.AllocOne() },
writeNodeTtBarrierWatcher,
&MockFlushScheduler{}, // todo replace mock with real flush scheduler
m.segmentAssigner)
if err != nil {
return nil, err
}
m.statProcessor = NewStatsProcessor(metakv,
func() (Timestamp, error) { return m.tsoAllocator.AllocOne() },
)
m.grpcServer = grpc.NewServer()
@ -199,7 +220,8 @@ func (s *Master) Close() {
log.Print("closing server")
s.stopServerLoop()
s.segmentAssigner.Close()
s.segmentManager.Close()
if s.kvBase != nil {
s.kvBase.Close()
}
@ -226,6 +248,8 @@ func (s *Master) Run(grpcPort int64) error {
if err := s.startServerLoop(s.ctx, grpcPort); err != nil {
return err
}
s.segmentAssigner.Start()
s.segmentManager.Start()
atomic.StoreInt64(&s.isServing, 1)
// Run callbacks
@ -268,7 +292,7 @@ func (s *Master) startServerLoop(ctx context.Context, grpcPort int64) error {
}
s.serverLoopWg.Add(1)
go s.segmentStatisticsLoop()
go s.statisticsLoop()
s.serverLoopWg.Add(1)
go s.tsLoop()
@ -348,7 +372,7 @@ func (s *Master) tsLoop() {
}
}
func (s *Master) segmentStatisticsLoop() {
func (s *Master) statisticsLoop() {
defer s.serverLoopWg.Done()
defer s.segmentStatusMsg.Close()
ctx, cancel := context.WithCancel(s.serverLoopCtx)
@ -357,7 +381,7 @@ func (s *Master) segmentStatisticsLoop() {
for {
select {
case msg := <-s.segmentStatusMsg.Chan():
err := s.segmentMgr.HandleQueryNodeMsgPack(msg)
err := s.statProcessor.ProcessQueryNodeStats(msg)
if err != nil {
log.Println(err)
}

View File

@ -0,0 +1,30 @@
package master
// persistenceScheduler is the scheduling contract used to hand segments off
// for persistence (flush): public Enqueue/Start/Close plus the internal
// schedule/scheduleLoop hooks.
type persistenceScheduler interface {
Enqueue(interface{}) error
schedule(interface{}) error
scheduleLoop()
Start() error
Close()
}
// MockFlushScheduler is a no-op persistenceScheduler; it stands in until a
// real flush scheduler is wired into the segment manager.
type MockFlushScheduler struct {
}
// Enqueue accepts any item and does nothing.
func (m *MockFlushScheduler) Enqueue(i interface{}) error {
return nil
}
// schedule does nothing and always succeeds.
func (m *MockFlushScheduler) schedule(i interface{}) error {
return nil
}
// scheduleLoop is intentionally empty.
func (m *MockFlushScheduler) scheduleLoop() {
}
// Start does nothing and always succeeds.
func (m *MockFlushScheduler) Start() error {
return nil
}
// Close does nothing.
func (m *MockFlushScheduler) Close() {
}

View File

@ -0,0 +1,187 @@
package master
import (
"context"
"log"
"sync"
"time"
"github.com/zilliztech/milvus-distributed/internal/util/tsoutil"
"github.com/zilliztech/milvus-distributed/internal/errors"
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
)
// Assignment records one outstanding row-count reservation on a segment and
// the timestamp at which the reservation expires.
type Assignment struct {
rowNums int
expireTime Timestamp
}
// Status tracks per-segment assignment state: the declared row capacity, the
// most recent expiration timestamp handed out, and the live assignments.
type Status struct {
total int
lastExpireTime Timestamp
assignments []*Assignment
}
// SegmentAssigner hands out row-count reservations on open segments and
// prunes expired reservations as proxy time-tick messages arrive.
type SegmentAssigner struct {
mt *metaTable
segmentStatus map[UniqueID]*Status //segment id -> status
globalTSOAllocator func() (Timestamp, error)
segmentExpireDuration int64 // assignment lifetime; from Params.SegIDAssignExpiration — presumably milliseconds, TODO confirm
proxyTimeSyncChan chan *ms.TimeTickMsg
ctx context.Context
cancel context.CancelFunc
waitGroup sync.WaitGroup
mu sync.Mutex // guards segmentStatus and its nested state
}
// OpenSegment registers a new segment with the given row capacity so that
// Assign can reserve rows from it. Reopening an already-tracked segment is
// an error.
func (assigner *SegmentAssigner) OpenSegment(segmentID UniqueID, numRows int) error {
assigner.mu.Lock()
defer assigner.mu.Unlock()
if _, ok := assigner.segmentStatus[segmentID]; ok {
return errors.Errorf("can not reopen segment %d", segmentID)
}
newStatus := &Status{
total: numRows,
assignments: make([]*Assignment, 0),
}
assigner.segmentStatus[segmentID] = newStatus
return nil
}
// CloseSegment removes a segment from assignment tracking; subsequent Assign
// calls for it will fail. Closing an unknown segment is an error.
func (assigner *SegmentAssigner) CloseSegment(segmentID UniqueID) error {
assigner.mu.Lock()
defer assigner.mu.Unlock()
if _, ok := assigner.segmentStatus[segmentID]; !ok {
return errors.Errorf("can not find segment %d", segmentID)
}
delete(assigner.segmentStatus, segmentID)
return nil
}
// Assign tries to reserve numRows rows in the given open segment.
// It returns (false, nil) when the segment lacks free capacity, and
// (true, nil) after recording a new assignment stamped with a fresh
// expiration timestamp.
func (assigner *SegmentAssigner) Assign(segmentID UniqueID, numRows int) (bool, error) {
	assigner.mu.Lock()
	defer assigner.mu.Unlock()
	status, ok := assigner.segmentStatus[segmentID]
	if !ok {
		return false, errors.Errorf("segment %d is not opened", segmentID)
	}
	allocated, err := assigner.totalOfAssignments(segmentID)
	if err != nil {
		return false, err
	}
	segMeta, err := assigner.mt.GetSegmentByID(segmentID)
	if err != nil {
		return false, err
	}
	// Free capacity = declared total - rows already persisted - rows held by
	// live (not yet expired) assignments.
	free := status.total - int(segMeta.NumRows) - allocated
	if numRows > free {
		return false, nil
	}
	ts, err := assigner.globalTSOAllocator()
	if err != nil {
		return false, err
	}
	physicalTs, logicalTs := tsoutil.ParseTS(ts)
	// segmentExpireDuration is a millisecond count (the pre-refactor code used
	// time.Duration(d)*time.Millisecond); passing it to time.Duration directly
	// would interpret it as nanoseconds and expire assignments immediately.
	expirePhysicalTs := physicalTs.Add(time.Duration(assigner.segmentExpireDuration) * time.Millisecond)
	expireTs := tsoutil.ComposeTS(expirePhysicalTs.UnixNano()/int64(time.Millisecond), int64(logicalTs))
	status.lastExpireTime = expireTs
	// Record the reservation with its expiration time (not the allocation
	// time ts, which syncProxyTimeStamp would prune on the next time tick).
	status.assignments = append(status.assignments, &Assignment{
		numRows,
		expireTs,
	})
	return true, nil
}
// CheckAssignmentExpired reports whether every assignment handed out for the
// segment has expired as of the given timestamp (i.e. the timestamp has
// reached the last expiration time issued).
func (assigner *SegmentAssigner) CheckAssignmentExpired(segmentID UniqueID, timestamp Timestamp) (bool, error) {
	assigner.mu.Lock()
	defer assigner.mu.Unlock()
	status, ok := assigner.segmentStatus[segmentID]
	if !ok {
		return false, errors.Errorf("can not find segment %d", segmentID)
	}
	return timestamp >= status.lastExpireTime, nil
}
// Start launches the background goroutine that consumes proxy time-tick
// messages and prunes expired assignments.
func (assigner *SegmentAssigner) Start() {
assigner.waitGroup.Add(1)
go assigner.startProxyTimeSync()
}
// Close cancels the assigner's context and waits for the time-sync goroutine
// to exit.
func (assigner *SegmentAssigner) Close() {
assigner.cancel()
assigner.waitGroup.Wait()
}
// startProxyTimeSync loops until the context is cancelled, forwarding each
// proxy time-tick timestamp to syncProxyTimeStamp; sync errors are logged
// and do not stop the loop.
func (assigner *SegmentAssigner) startProxyTimeSync() {
defer assigner.waitGroup.Done()
for {
select {
case <-assigner.ctx.Done():
log.Println("proxy time sync stopped")
return
case msg := <-assigner.proxyTimeSyncChan:
if err := assigner.syncProxyTimeStamp(msg.TimeTickMsg.Timestamp); err != nil {
log.Println("proxy time sync error: " + err.Error())
}
}
}
}
// totalOfAssignments sums the rows reserved by all live assignments of the
// segment. The caller must hold assigner.mu.
func (assigner *SegmentAssigner) totalOfAssignments(segmentID UniqueID) (int, error) {
	// Single comma-ok lookup (the previous version looked the key up twice).
	status, ok := assigner.segmentStatus[segmentID]
	if !ok {
		return -1, errors.Errorf("can not find segment %d", segmentID)
	}
	res := 0
	for _, v := range status.assignments {
		res += v.rowNums
	}
	return res, nil
}
// syncProxyTimeStamp drops every assignment whose expireTime has been passed
// by the proxy time tick, releasing the rows it reserved.
func (assigner *SegmentAssigner) syncProxyTimeStamp(timeTick Timestamp) error {
assigner.mu.Lock()
defer assigner.mu.Unlock()
for _, status := range assigner.segmentStatus {
for i := 0; i < len(status.assignments); {
if timeTick >= status.assignments[i].expireTime {
// Swap-remove: overwrite slot i with the last element and shrink.
// Deliberately do not advance i — the swapped-in element still
// needs to be checked.
status.assignments[i] = status.assignments[len(status.assignments)-1]
status.assignments = status.assignments[:len(status.assignments)-1]
continue
}
i++
}
}
return nil
}
// NewSegmentAssigner builds a SegmentAssigner bound to a child context of
// ctx, taking the expiration duration from Params.SegIDAssignExpiration.
// Call Start to begin consuming proxy time-tick messages and Close to stop.
func NewSegmentAssigner(ctx context.Context, metaTable *metaTable,
globalTSOAllocator func() (Timestamp, error), proxyTimeSyncChan chan *ms.TimeTickMsg) *SegmentAssigner {
assignCtx, cancel := context.WithCancel(ctx)
return &SegmentAssigner{
mt: metaTable,
segmentStatus: make(map[UniqueID]*Status),
globalTSOAllocator: globalTSOAllocator,
segmentExpireDuration: Params.SegIDAssignExpiration,
proxyTimeSyncChan: proxyTimeSyncChan,
ctx: assignCtx,
cancel: cancel,
}
}

View File

@ -1,250 +1,191 @@
package master
import (
"context"
"log"
"sync"
"time"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/etcdpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
ms "github.com/zilliztech/milvus-distributed/internal/msgstream"
)
type collectionStatus struct {
openedSegments []UniqueID
segments []*segmentStatus
}
type assignment struct {
MemSize int64 // bytes
AssignTime time.Time
type segmentStatus struct {
segmentID UniqueID
total int
closable bool
}
type channelRange struct {
channelStart int32
channelEnd int32
}
type segmentStatus struct {
assignments []*assignment
}
type SegmentManager struct {
metaTable *metaTable
statsStream msgstream.MsgStream
channelRanges []*channelRange
segmentStatus map[UniqueID]*segmentStatus // segment id to segment status
collStatus map[UniqueID]*collectionStatus // collection id to collection status
defaultSizePerRecord int64
minimumAssignSize int64
segmentThreshold float64
segmentThresholdFactor float64
segmentExpireDuration int64
numOfChannels int
numOfQueryNodes int
globalIDAllocator func() (UniqueID, error)
globalTSOAllocator func() (Timestamp, error)
mu sync.RWMutex
assigner *SegmentAssigner
writeNodeTimeSyncChan chan *ms.TimeTickMsg
flushScheduler persistenceScheduler
ctx context.Context
cancel context.CancelFunc
waitGroup sync.WaitGroup
}
func (segMgr *SegmentManager) HandleQueryNodeMsgPack(msgPack *msgstream.MsgPack) error {
segMgr.mu.Lock()
defer segMgr.mu.Unlock()
for _, msg := range msgPack.Msgs {
statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg)
if !ok {
return errors.Errorf("Type of message is not QueryNodeStatsMsg")
}
func (manager *SegmentManager) AssignSegment(segIDReq []*internalpb.SegIDRequest) ([]*internalpb.SegIDAssignment, error) {
manager.mu.Lock()
defer manager.mu.Unlock()
for _, segStat := range statsMsg.GetSegStats() {
err := segMgr.handleSegmentStat(segStat)
if err != nil {
return err
}
}
}
return nil
}
func (segMgr *SegmentManager) handleSegmentStat(segStats *internalpb.SegmentStats) error {
if !segStats.GetRecentlyModified() {
return nil
}
segID := segStats.GetSegmentID()
segMeta, err := segMgr.metaTable.GetSegmentByID(segID)
if err != nil {
return err
}
segMeta.NumRows = segStats.NumRows
segMeta.MemSize = segStats.MemorySize
if segStats.MemorySize > int64(segMgr.segmentThresholdFactor*segMgr.segmentThreshold) {
if err := segMgr.metaTable.UpdateSegment(segMeta); err != nil {
return err
}
return segMgr.closeSegment(segMeta)
}
return segMgr.metaTable.UpdateSegment(segMeta)
}
func (segMgr *SegmentManager) closeSegment(segMeta *etcdpb.SegmentMeta) error {
if segMeta.GetCloseTime() == 0 {
// close the segment and remove from collStatus
collStatus, ok := segMgr.collStatus[segMeta.GetCollectionID()]
if ok {
openedSegments := collStatus.openedSegments
for i, openedSegID := range openedSegments {
if openedSegID == segMeta.SegmentID {
openedSegments[i] = openedSegments[len(openedSegments)-1]
collStatus.openedSegments = openedSegments[:len(openedSegments)-1]
break
}
}
}
ts, err := segMgr.globalTSOAllocator()
if err != nil {
return err
}
segMeta.CloseTime = ts
}
err := segMgr.metaTable.CloseSegment(segMeta.SegmentID, segMeta.GetCloseTime())
if err != nil {
return err
}
return nil
}
func (segMgr *SegmentManager) AssignSegmentID(segIDReq []*internalpb.SegIDRequest) ([]*internalpb.SegIDAssignment, error) {
segMgr.mu.Lock()
defer segMgr.mu.Unlock()
res := make([]*internalpb.SegIDAssignment, 0)
for _, req := range segIDReq {
collName := req.CollName
partitionTag := req.PartitionTag
count := req.Count
channelID := req.ChannelID
collMeta, err := segMgr.metaTable.GetCollectionByName(collName)
collMeta, err := manager.metaTable.GetCollectionByName(collName)
if err != nil {
return nil, err
}
collID := collMeta.GetID()
if !segMgr.metaTable.HasCollection(collID) {
return nil, errors.Errorf("can not find collection with id=%d", collID)
}
if !segMgr.metaTable.HasPartition(collID, partitionTag) {
if !manager.metaTable.HasPartition(collID, partitionTag) {
return nil, errors.Errorf("partition tag %s can not find in coll %d", partitionTag, collID)
}
collStatus, ok := segMgr.collStatus[collID]
if !ok {
collStatus = &collectionStatus{
openedSegments: make([]UniqueID, 0),
}
segMgr.collStatus[collID] = collStatus
}
assignInfo, err := segMgr.assignSegment(collName, collID, partitionTag, count, channelID, collStatus)
assignInfo, err := manager.assignSegment(collName, collID, partitionTag, count, channelID)
if err != nil {
return nil, err
}
res = append(res, assignInfo)
}
return res, nil
}
func (segMgr *SegmentManager) assignSegment(collName string, collID UniqueID, partitionTag string, count uint32, channelID int32,
collStatus *collectionStatus) (*internalpb.SegIDAssignment, error) {
segmentThreshold := int64(segMgr.segmentThreshold)
for _, segID := range collStatus.openedSegments {
segMeta, _ := segMgr.metaTable.GetSegmentByID(segID)
if segMeta.GetCloseTime() != 0 || channelID < segMeta.GetChannelStart() ||
channelID > segMeta.GetChannelEnd() || segMeta.PartitionTag != partitionTag {
func (manager *SegmentManager) assignSegment(
collName string,
collID UniqueID,
partitionTag string,
count uint32,
channelID int32) (*internalpb.SegIDAssignment, error) {
collStatus, ok := manager.collStatus[collID]
if !ok {
collStatus = &collectionStatus{
segments: make([]*segmentStatus, 0),
}
manager.collStatus[collID] = collStatus
}
for _, segStatus := range collStatus.segments {
if segStatus.closable {
continue
}
// check whether segment has enough mem size
assignedMem := segMgr.checkAssignedSegExpire(segID)
memSize := segMeta.MemSize
neededMemSize := segMgr.calNeededSize(memSize, segMeta.NumRows, int64(count))
if memSize+assignedMem+neededMemSize <= segmentThreshold {
remainingSize := segmentThreshold - memSize - assignedMem
allocMemSize := segMgr.calAllocMemSize(neededMemSize, remainingSize)
segMgr.addAssignment(segID, allocMemSize)
return &internalpb.SegIDAssignment{
SegID: segID,
ChannelID: channelID,
Count: uint32(segMgr.calNumRows(memSize, segMeta.NumRows, allocMemSize)),
CollName: collName,
PartitionTag: partitionTag,
}, nil
match, err := manager.isMatch(segStatus.segmentID, partitionTag, channelID)
if err != nil {
return nil, err
}
}
neededMemSize := segMgr.defaultSizePerRecord * int64(count)
if neededMemSize > segmentThreshold {
return nil, errors.Errorf("request with count %d need about %d mem size which is larger than segment threshold",
count, neededMemSize)
if !match {
continue
}
result, err := manager.assigner.Assign(segStatus.segmentID, int(count))
if err != nil {
return nil, err
}
if !result {
continue
}
return &internalpb.SegIDAssignment{
SegID: segStatus.segmentID,
ChannelID: channelID,
Count: count,
CollName: collName,
PartitionTag: partitionTag,
}, nil
}
segMeta, err := segMgr.openNewSegment(channelID, collID, partitionTag)
total, err := manager.estimateTotalRows(collName)
if err != nil {
return nil, err
}
if int(count) > total {
return nil, errors.Errorf("request count %d is larger than total rows %d", count, total)
}
id, err := manager.openNewSegment(channelID, collID, partitionTag, total)
if err != nil {
return nil, err
}
allocMemSize := segMgr.calAllocMemSize(neededMemSize, segmentThreshold)
segMgr.addAssignment(segMeta.SegmentID, allocMemSize)
result, err := manager.assigner.Assign(id, int(count))
if err != nil {
return nil, err
}
if !result {
return nil, errors.Errorf("assign failed for segment %d", id)
}
return &internalpb.SegIDAssignment{
SegID: segMeta.SegmentID,
SegID: id,
ChannelID: channelID,
Count: uint32(segMgr.calNumRows(0, 0, allocMemSize)),
Count: count,
CollName: collName,
PartitionTag: partitionTag,
}, nil
}
func (segMgr *SegmentManager) addAssignment(segID UniqueID, allocSize int64) {
segStatus := segMgr.segmentStatus[segID]
segStatus.assignments = append(segStatus.assignments, &assignment{
MemSize: allocSize,
AssignTime: time.Now(),
})
// isMatch reports whether an existing segment can serve a request: its
// channel range must contain channelID and its partition tag must equal
// partitionTag. Errors from the meta lookup are propagated.
func (manager *SegmentManager) isMatch(segmentID UniqueID, partitionTag string, channelID int32) (bool, error) {
segMeta, err := manager.metaTable.GetSegmentByID(segmentID)
if err != nil {
return false, err
}
if channelID < segMeta.GetChannelStart() ||
channelID > segMeta.GetChannelEnd() || segMeta.PartitionTag != partitionTag {
return false, nil
}
return true, nil
}
func (segMgr *SegmentManager) calNeededSize(memSize int64, numRows int64, count int64) int64 {
var avgSize int64
if memSize == 0 || numRows == 0 || memSize/numRows == 0 {
avgSize = segMgr.defaultSizePerRecord
} else {
avgSize = memSize / numRows
// estimateTotalRows estimates how many rows fit in one segment of the named
// collection: segmentThreshold (bytes) divided by the schema's estimated
// per-record size.
// NOTE(review): despite the name this is a per-segment row capacity, not the
// collection's total row count — consider renaming at the next API break.
func (manager *SegmentManager) estimateTotalRows(collName string) (int, error) {
collMeta, err := manager.metaTable.GetCollectionByName(collName)
if err != nil {
return -1, err
}
sizePerRecord, err := typeutil.EstimateSizePerRecord(collMeta.Schema)
if err != nil {
return -1, err
}
return int(manager.segmentThreshold / float64(sizePerRecord)), nil
}
}
func (segMgr *SegmentManager) calAllocMemSize(neededSize int64, remainSize int64) int64 {
if neededSize > remainSize {
return 0
}
if remainSize < segMgr.minimumAssignSize {
return remainSize
}
if neededSize < segMgr.minimumAssignSize {
return segMgr.minimumAssignSize
}
return neededSize
}
func (segMgr *SegmentManager) calNumRows(memSize int64, numRows int64, allocMemSize int64) int64 {
var avgSize int64
if memSize == 0 || numRows == 0 || memSize/numRows == 0 {
avgSize = segMgr.defaultSizePerRecord
} else {
avgSize = memSize / numRows
}
return allocMemSize / avgSize
}
func (segMgr *SegmentManager) openNewSegment(channelID int32, collID UniqueID, partitionTag string) (*etcdpb.SegmentMeta, error) {
func (manager *SegmentManager) openNewSegment(channelID int32, collID UniqueID, partitionTag string, numRows int) (UniqueID, error) {
// find the channel range
channelStart, channelEnd := int32(-1), int32(-1)
for _, r := range segMgr.channelRanges {
for _, r := range manager.channelRanges {
if channelID >= r.channelStart && channelID <= r.channelEnd {
channelStart = r.channelStart
channelEnd = r.channelEnd
@ -252,18 +193,19 @@ func (segMgr *SegmentManager) openNewSegment(channelID int32, collID UniqueID, p
}
}
if channelStart == -1 {
return nil, errors.Errorf("can't find the channel range which contains channel %d", channelID)
return -1, errors.Errorf("can't find the channel range which contains channel %d", channelID)
}
newID, err := segMgr.globalIDAllocator()
newID, err := manager.globalIDAllocator()
if err != nil {
return nil, err
return -1, err
}
openTime, err := segMgr.globalTSOAllocator()
openTime, err := manager.globalTSOAllocator()
if err != nil {
return nil, err
return -1, err
}
newSegMeta := &etcdpb.SegmentMeta{
err = manager.metaTable.AddSegment(&etcdpb.SegmentMeta{
SegmentID: newID,
CollectionID: collID,
PartitionTag: partitionTag,
@ -272,51 +214,119 @@ func (segMgr *SegmentManager) openNewSegment(channelID int32, collID UniqueID, p
OpenTime: openTime,
NumRows: 0,
MemSize: 0,
}
err = segMgr.metaTable.AddSegment(newSegMeta)
})
if err != nil {
return nil, err
return -1, err
}
segMgr.segmentStatus[newID] = &segmentStatus{
assignments: make([]*assignment, 0),
err = manager.assigner.OpenSegment(newID, numRows)
if err != nil {
return -1, err
}
collStatus := segMgr.collStatus[collID]
collStatus.openedSegments = append(collStatus.openedSegments, newSegMeta.SegmentID)
return newSegMeta, nil
segStatus := &segmentStatus{
segmentID: newID,
total: numRows,
closable: false,
}
collStatus := manager.collStatus[collID]
collStatus.segments = append(collStatus.segments, segStatus)
return newID, nil
}
// checkAssignedSegExpire check the expire time of assignments and return the total sum of assignments that are not expired.
func (segMgr *SegmentManager) checkAssignedSegExpire(segID UniqueID) int64 {
segStatus := segMgr.segmentStatus[segID]
assignments := segStatus.assignments
result := int64(0)
i := 0
for i < len(assignments) {
assign := assignments[i]
if time.Since(assign.AssignTime) >= time.Duration(segMgr.segmentExpireDuration)*time.Millisecond {
assignments[i] = assignments[len(assignments)-1]
assignments = assignments[:len(assignments)-1]
continue
// Start launches the background goroutine that consumes write-node
// time-tick messages. Pair with Close to stop it.
func (manager *SegmentManager) Start() {
	manager.waitGroup.Add(1)
	go manager.startWriteNodeTimeSync()
}
// Close cancels the manager's context and blocks until the background
// time-sync goroutine started by Start has exited.
func (manager *SegmentManager) Close() {
	manager.cancel()
	manager.waitGroup.Wait()
}
func (manager *SegmentManager) startWriteNodeTimeSync() {
defer manager.waitGroup.Done()
for {
select {
case <-manager.ctx.Done():
log.Println("write node time sync stopped")
return
case msg := <-manager.writeNodeTimeSyncChan:
if err := manager.syncWriteNodeTimestamp(msg.TimeTickMsg.Timestamp); err != nil {
log.Println("write node time sync error: " + err.Error())
}
}
result += assign.MemSize
i++
}
segStatus.assignments = assignments
return result
}
func (segMgr *SegmentManager) createChannelRanges() error {
div, rem := segMgr.numOfChannels/segMgr.numOfQueryNodes, segMgr.numOfChannels%segMgr.numOfQueryNodes
for i, j := 0, 0; i < segMgr.numOfChannels; j++ {
// syncWriteNodeTimestamp reacts to a write-node time tick: for every
// tracked segment it (lazily) decides closability, checks whether all
// outstanding assignments have expired at timeTick, and if so removes the
// segment from tracking, closes it in the meta table and the assigner,
// and enqueues it for flushing. Returns on the first error.
func (manager *SegmentManager) syncWriteNodeTimestamp(timeTick Timestamp) error {
	manager.mu.Lock()
	defer manager.mu.Unlock()
	for _, status := range manager.collStatus {
		// Iterate backwards by index: the original forward range removed
		// elements while iterating, which skips the element shifted into
		// slot i and revisits a stale trailing entry (range captures the
		// slice header once).
		for i := len(status.segments) - 1; i >= 0; i-- {
			segStatus := status.segments[i]
			if !segStatus.closable {
				closable, err := manager.judgeSegmentClosable(segStatus)
				if err != nil {
					return err
				}
				segStatus.closable = closable
				if !segStatus.closable {
					continue
				}
			}
			isExpired, err := manager.assigner.CheckAssignmentExpired(segStatus.segmentID, timeTick)
			if err != nil {
				return err
			}
			if !isExpired {
				continue
			}
			// Closable and fully expired: stop tracking, then persist the
			// close and hand the segment to the flush scheduler.
			status.segments = append(status.segments[:i], status.segments[i+1:]...)
			ts, err := manager.globalTSOAllocator()
			if err != nil {
				return err
			}
			if err = manager.metaTable.CloseSegment(segStatus.segmentID, ts); err != nil {
				return err
			}
			if err = manager.assigner.CloseSegment(segStatus.segmentID); err != nil {
				return err
			}
			if err = manager.flushScheduler.Enqueue(segStatus.segmentID); err != nil {
				return err
			}
		}
	}
	return nil
}
// judgeSegmentClosable reports whether the segment has grown past its
// close threshold: segmentThresholdFactor times the row capacity planned
// when the segment was opened (status.total).
func (manager *SegmentManager) judgeSegmentClosable(status *segmentStatus) (bool, error) {
	segMeta, err := manager.metaTable.GetSegmentByID(status.segmentID)
	if err != nil {
		return false, err
	}
	threshold := int64(manager.segmentThresholdFactor * float64(status.total))
	return segMeta.NumRows >= threshold, nil
}
func (manager *SegmentManager) initChannelRanges() error {
div, rem := manager.numOfChannels/manager.numOfQueryNodes, manager.numOfChannels%manager.numOfQueryNodes
for i, j := 0, 0; i < manager.numOfChannels; j++ {
if j < rem {
segMgr.channelRanges = append(segMgr.channelRanges, &channelRange{
manager.channelRanges = append(manager.channelRanges, &channelRange{
channelStart: int32(i),
channelEnd: int32(i + div),
})
i += div + 1
} else {
segMgr.channelRanges = append(segMgr.channelRanges, &channelRange{
manager.channelRanges = append(manager.channelRanges, &channelRange{
channelStart: int32(i),
channelEnd: int32(i + div - 1),
})
@ -325,26 +335,38 @@ func (segMgr *SegmentManager) createChannelRanges() error {
}
return nil
}
func NewSegmentManager(meta *metaTable,
func NewSegmentManager(ctx context.Context,
meta *metaTable,
globalIDAllocator func() (UniqueID, error),
globalTSOAllocator func() (Timestamp, error),
) *SegmentManager {
segMgr := &SegmentManager{
syncWriteNodeChan chan *ms.TimeTickMsg,
scheduler persistenceScheduler,
assigner *SegmentAssigner) (*SegmentManager, error) {
assignerCtx, cancel := context.WithCancel(ctx)
segAssigner := &SegmentManager{
metaTable: meta,
channelRanges: make([]*channelRange, 0),
segmentStatus: make(map[UniqueID]*segmentStatus),
collStatus: make(map[UniqueID]*collectionStatus),
segmentThreshold: Params.SegmentSize * 1024 * 1024,
segmentThresholdFactor: Params.SegmentSizeFactor,
segmentExpireDuration: Params.SegIDAssignExpiration,
minimumAssignSize: Params.MinSegIDAssignCnt * Params.DefaultRecordSize,
defaultSizePerRecord: Params.DefaultRecordSize,
numOfChannels: Params.TopicNum,
numOfQueryNodes: Params.QueryNodeNum,
globalIDAllocator: globalIDAllocator,
globalTSOAllocator: globalTSOAllocator,
assigner: assigner,
writeNodeTimeSyncChan: syncWriteNodeChan,
flushScheduler: scheduler,
ctx: assignerCtx,
cancel: cancel,
}
segMgr.createChannelRanges()
return segMgr
if err := segAssigner.initChannelRanges(); err != nil {
return nil, err
}
return segAssigner, nil
}

View File

@ -2,15 +2,12 @@ package master
import (
"context"
"log"
"sync/atomic"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/kv"
etcdkv "github.com/zilliztech/milvus-distributed/internal/kv/etcd"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
@ -25,61 +22,11 @@ import (
"google.golang.org/grpc"
)
var mt *metaTable
var segMgr *SegmentManager
var collName = "coll_segmgr_test"
var collID = int64(1001)
var partitionTag = "test"
var kvBase kv.TxnBase
var master *Master
var masterCancelFunc context.CancelFunc
func TestSegmentManager_AssignSegment(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.TODO())
defer cancelFunc()
func setup() {
Init()
etcdAddress := Params.EtcdAddress
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
if err != nil {
panic(err)
}
rootPath := "/test/root"
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
_, err = cli.Delete(ctx, rootPath, clientv3.WithPrefix())
if err != nil {
panic(err)
}
kvBase = etcdkv.NewEtcdKV(cli, rootPath)
tmpMt, err := NewMetaTable(kvBase)
if err != nil {
panic(err)
}
mt = tmpMt
if mt.HasCollection(collID) {
err := mt.DeleteCollection(collID)
if err != nil {
panic(err)
}
}
err = mt.AddCollection(&pb.CollectionMeta{
ID: collID,
Schema: &schemapb.CollectionSchema{
Name: collName,
},
CreateTime: 0,
SegmentIDs: []UniqueID{},
PartitionTags: []string{},
})
if err != nil {
panic(err)
}
err = mt.AddPartition(collID, partitionTag)
if err != nil {
panic(err)
}
var cnt int64
Params.TopicNum = 5
Params.QueryNodeNum = 3
Params.SegmentSize = 536870912 / 1024 / 1024
@ -87,151 +34,143 @@ func setup() {
Params.DefaultRecordSize = 1024
Params.MinSegIDAssignCnt = 1048576 / 1024
Params.SegIDAssignExpiration = 2000
etcdAddress := Params.EtcdAddress
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
assert.Nil(t, err)
rootPath := "/test/root"
_, err = cli.Delete(ctx, rootPath, clientv3.WithPrefix())
assert.Nil(t, err)
segMgr = NewSegmentManager(mt,
func() (UniqueID, error) {
val := atomic.AddInt64(&cnt, 1)
return val, nil
kvBase := etcdkv.NewEtcdKV(cli, rootPath)
defer kvBase.Close()
mt, err := NewMetaTable(kvBase)
assert.Nil(t, err)
collName := "segmgr_test_coll"
var collID int64 = 1001
partitionTag := "test_part"
schema := &schemapb.CollectionSchema{
Name: collName,
Fields: []*schemapb.FieldSchema{
{FieldID: 1, Name: "f1", IsPrimaryKey: false, DataType: schemapb.DataType_INT32},
{FieldID: 2, Name: "f2", IsPrimaryKey: false, DataType: schemapb.DataType_VECTOR_FLOAT, TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "128"},
}},
},
func() (Timestamp, error) {
val := atomic.AddInt64(&cnt, 1)
phy := time.Now().UnixNano() / int64(time.Millisecond)
ts := tsoutil.ComposeTS(phy, val)
return ts, nil
},
)
}
// teardown deletes the test collection and releases the shared etcd kv
// handle created by setup.
func teardown() {
	if err := mt.DeleteCollection(collID); err != nil {
		// Use a constant format string: the original log.Fatalf(err.Error())
		// treats the error text as a format string, which misformats any
		// message containing '%' verbs (flagged by go vet's printf check).
		log.Fatalf("failed to delete collection %d: %v", collID, err)
	}
	kvBase.Close()
}
err = mt.AddCollection(&pb.CollectionMeta{
ID: collID,
Schema: schema,
CreateTime: 0,
SegmentIDs: []UniqueID{},
PartitionTags: []string{},
})
assert.Nil(t, err)
err = mt.AddPartition(collID, partitionTag)
assert.Nil(t, err)
func TestSegmentManager_AssignSegmentID(t *testing.T) {
setup()
defer teardown()
reqs := []*internalpb.SegIDRequest{
{CollName: collName, PartitionTag: partitionTag, Count: 25000, ChannelID: 0},
{CollName: collName, PartitionTag: partitionTag, Count: 10000, ChannelID: 1},
{CollName: collName, PartitionTag: partitionTag, Count: 30000, ChannelID: 2},
{CollName: collName, PartitionTag: partitionTag, Count: 25000, ChannelID: 3},
{CollName: collName, PartitionTag: partitionTag, Count: 10000, ChannelID: 4},
var cnt int64
globalIDAllocator := func() (UniqueID, error) {
val := atomic.AddInt64(&cnt, 1)
return val, nil
}
globalTsoAllocator := func() (Timestamp, error) {
val := atomic.AddInt64(&cnt, 1)
phy := time.Now().UnixNano() / int64(time.Millisecond)
ts := tsoutil.ComposeTS(phy, val)
return ts, nil
}
syncWriteChan := make(chan *msgstream.TimeTickMsg)
syncProxyChan := make(chan *msgstream.TimeTickMsg)
segAssigner := NewSegmentAssigner(ctx, mt, globalTsoAllocator, syncProxyChan)
mockScheduler := &MockFlushScheduler{}
segManager, err := NewSegmentManager(ctx, mt, globalIDAllocator, globalTsoAllocator, syncWriteChan, mockScheduler, segAssigner)
assert.Nil(t, err)
segManager.Start()
defer segManager.Close()
sizePerRecord, err := typeutil.EstimateSizePerRecord(schema)
assert.Nil(t, err)
maxCount := uint32(Params.SegmentSize * 1024 * 1024 / float64(sizePerRecord))
cases := []struct {
Count uint32
ChannelID int32
Err bool
SameIDWith int
NotSameIDWith int
ResultCount int32
}{
{1000, 1, false, -1, -1, 1000},
{2000, 0, false, 0, -1, 2000},
{maxCount - 2999, 1, false, -1, 0, int32(maxCount - 2999)},
{maxCount - 3000, 1, false, 0, -1, int32(maxCount - 3000)},
{2000000000, 1, true, -1, -1, -1},
{1000, 3, false, -1, 0, 1000},
{maxCount, 2, false, -1, -1, int32(maxCount)},
}
segAssigns, err := segMgr.AssignSegmentID(reqs)
assert.Nil(t, err)
assert.Equal(t, uint32(25000), segAssigns[0].Count)
assert.Equal(t, uint32(10000), segAssigns[1].Count)
assert.Equal(t, uint32(30000), segAssigns[2].Count)
assert.Equal(t, uint32(25000), segAssigns[3].Count)
assert.Equal(t, uint32(10000), segAssigns[4].Count)
assert.Equal(t, segAssigns[0].SegID, segAssigns[1].SegID)
assert.Equal(t, segAssigns[2].SegID, segAssigns[3].SegID)
newReqs := []*internalpb.SegIDRequest{
{CollName: collName, PartitionTag: partitionTag, Count: 500000, ChannelID: 0},
var results = make([]*internalpb.SegIDAssignment, 0)
for _, c := range cases {
result, err := segManager.AssignSegment([]*internalpb.SegIDRequest{{Count: c.Count, ChannelID: c.ChannelID, CollName: collName, PartitionTag: partitionTag}})
results = append(results, result...)
if c.Err {
assert.NotNil(t, err)
continue
}
assert.Nil(t, err)
if c.SameIDWith != -1 {
assert.EqualValues(t, result[0].SegID, results[c.SameIDWith].SegID)
}
if c.NotSameIDWith != -1 {
assert.NotEqualValues(t, result[0].SegID, results[c.NotSameIDWith].SegID)
}
if c.ResultCount != -1 {
assert.EqualValues(t, result[0].Count, c.ResultCount)
}
}
// test open a new segment
newAssign, err := segMgr.AssignSegmentID(newReqs)
time.Sleep(time.Duration(Params.SegIDAssignExpiration))
timestamp, err := globalTsoAllocator()
assert.Nil(t, err)
assert.NotNil(t, newAssign)
assert.Equal(t, uint32(500000), newAssign[0].Count)
assert.NotEqual(t, segAssigns[0].SegID, newAssign[0].SegID)
// test assignment expiration
time.Sleep(3 * time.Second)
assignAfterExpiration, err := segMgr.AssignSegmentID(newReqs)
assert.Nil(t, err)
assert.NotNil(t, assignAfterExpiration)
assert.Equal(t, uint32(500000), assignAfterExpiration[0].Count)
assert.Equal(t, segAssigns[0].SegID, assignAfterExpiration[0].SegID)
// test invalid params
newReqs[0].CollName = "wrong_collname"
_, err = segMgr.AssignSegmentID(newReqs)
assert.Error(t, errors.Errorf("can not find collection with id=%d", collID), err)
newReqs[0].Count = 1000000
_, err = segMgr.AssignSegmentID(newReqs)
assert.Error(t, errors.Errorf("request with count %d need about %d mem size which is larger than segment threshold",
1000000, 1024*1000000), err)
}
func TestSegmentManager_SegmentStats(t *testing.T) {
setup()
defer teardown()
ts, err := segMgr.globalTSOAllocator()
assert.Nil(t, err)
err = mt.AddSegment(&pb.SegmentMeta{
SegmentID: 100,
err = mt.UpdateSegment(&pb.SegmentMeta{
SegmentID: results[0].SegID,
CollectionID: collID,
PartitionTag: partitionTag,
ChannelStart: 0,
ChannelEnd: 1,
OpenTime: ts,
CloseTime: timestamp,
NumRows: 400000,
MemSize: 500000,
})
assert.Nil(t, err)
stats := internalpb.QueryNodeStats{
MsgType: internalpb.MsgType_kQueryNodeStats,
PeerID: 1,
SegStats: []*internalpb.SegmentStats{
{SegmentID: 100, MemorySize: 2500000, NumRows: 25000, RecentlyModified: true},
tsMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: timestamp, EndTimestamp: timestamp, HashValues: []uint32{},
},
TimeTickMsg: internalpb.TimeTickMsg{
MsgType: internalpb.MsgType_kTimeTick,
PeerID: 1,
Timestamp: timestamp,
},
}
baseMsg := msgstream.BaseMsg{
BeginTimestamp: 0,
EndTimestamp: 0,
HashValues: []uint32{1},
}
msg := msgstream.QueryNodeStatsMsg{
QueryNodeStats: stats,
BaseMsg: baseMsg,
}
var tsMsg msgstream.TsMsg = &msg
msgPack := msgstream.MsgPack{
Msgs: make([]msgstream.TsMsg, 0),
}
msgPack.Msgs = append(msgPack.Msgs, tsMsg)
err = segMgr.HandleQueryNodeMsgPack(&msgPack)
syncWriteChan <- tsMsg
time.Sleep(300 * time.Millisecond)
segMeta, err := mt.GetSegmentByID(results[0].SegID)
assert.Nil(t, err)
segMeta, _ := mt.GetSegmentByID(100)
assert.Equal(t, int64(100), segMeta.SegmentID)
assert.Equal(t, int64(2500000), segMeta.MemSize)
assert.Equal(t, int64(25000), segMeta.NumRows)
// close segment
stats.SegStats[0].NumRows = 600000
stats.SegStats[0].MemorySize = int64(0.8 * segMgr.segmentThreshold)
err = segMgr.HandleQueryNodeMsgPack(&msgPack)
assert.Nil(t, err)
segMeta, _ = mt.GetSegmentByID(100)
assert.Equal(t, int64(100), segMeta.SegmentID)
assert.NotEqual(t, uint64(0), segMeta.CloseTime)
assert.NotEqualValues(t, 0, segMeta.CloseTime)
}
func startupMaster() {
func TestSegmentManager_RPC(t *testing.T) {
Init()
refreshMasterAddress()
etcdAddress := Params.EtcdAddress
rootPath := "/test/root"
ctx, cancel := context.WithCancel(context.TODO())
masterCancelFunc = cancel
defer cancel()
cli, err := clientv3.New(clientv3.Config{Endpoints: []string{etcdAddress}})
if err != nil {
panic(err)
}
assert.Nil(t, err)
_, err = cli.Delete(ctx, rootPath, clientv3.WithPrefix())
if err != nil {
panic(err)
}
assert.Nil(t, err)
Params = ParamTable{
Address: Params.Address,
Port: Params.Port,
@ -268,27 +207,13 @@ func startupMaster() {
DefaultPartitionTag: "_default",
}
master, err = CreateServer(ctx)
if err != nil {
panic(err)
}
collName := "test_coll"
partitionTag := "test_part"
master, err := CreateServer(ctx)
assert.Nil(t, err)
defer master.Close()
err = master.Run(int64(Params.Port))
if err != nil {
panic(err)
}
}
func shutdownMaster() {
masterCancelFunc()
master.Close()
}
func TestSegmentManager_RPC(t *testing.T) {
startupMaster()
defer shutdownMaster()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
assert.Nil(t, err)
dialContext, err := grpc.DialContext(ctx, Params.Address, grpc.WithInsecure(), grpc.WithBlock())
assert.Nil(t, err)
defer dialContext.Close()
@ -297,7 +222,10 @@ func TestSegmentManager_RPC(t *testing.T) {
Name: collName,
Description: "test coll",
AutoID: false,
Fields: []*schemapb.FieldSchema{},
Fields: []*schemapb.FieldSchema{
{FieldID: 1, Name: "f1", IsPrimaryKey: false, DataType: schemapb.DataType_INT32},
{FieldID: 1, Name: "f1", IsPrimaryKey: false, DataType: schemapb.DataType_VECTOR_FLOAT, TypeParams: []*commonpb.KeyValuePair{{Key: "dim", Value: "128"}}},
},
}
schemaBytes, err := proto.Marshal(schema)
assert.Nil(t, err)
@ -329,13 +257,13 @@ func TestSegmentManager_RPC(t *testing.T) {
},
})
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_SUCCESS, resp.Status.ErrorCode)
assert.EqualValues(t, commonpb.ErrorCode_SUCCESS, resp.Status.ErrorCode)
assignments := resp.GetPerChannelAssignment()
assert.Equal(t, 1, len(assignments))
assert.Equal(t, collName, assignments[0].CollName)
assert.Equal(t, partitionTag, assignments[0].PartitionTag)
assert.Equal(t, int32(0), assignments[0].ChannelID)
assert.Equal(t, uint32(10000), assignments[0].Count)
assert.EqualValues(t, 1, len(assignments))
assert.EqualValues(t, collName, assignments[0].CollName)
assert.EqualValues(t, partitionTag, assignments[0].PartitionTag)
assert.EqualValues(t, int32(0), assignments[0].ChannelID)
assert.EqualValues(t, uint32(10000), assignments[0].Count)
// test stats
segID := assignments[0].SegID
@ -369,6 +297,6 @@ func TestSegmentManager_RPC(t *testing.T) {
time.Sleep(500 * time.Millisecond)
segMeta, err := master.metaTable.GetSegmentByID(segID)
assert.Nil(t, err)
assert.NotEqual(t, uint64(0), segMeta.GetCloseTime())
assert.Equal(t, int64(600000000), segMeta.GetMemSize())
assert.EqualValues(t, 1000000, segMeta.GetNumRows())
assert.EqualValues(t, int64(600000000), segMeta.GetMemSize())
}

View File

@ -0,0 +1,58 @@
package master
import (
"github.com/zilliztech/milvus-distributed/internal/errors"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
)
// StatsProcessor consumes query-node statistics messages and writes the
// reported per-segment row counts and memory sizes into the meta table.
type StatsProcessor struct {
	metaTable *metaTable // persistent segment metadata store
	// NOTE(review): segmentThreshold / segmentThresholdFactor are set by
	// NewStatsProcessor but not read in this file — presumably reserved
	// for a closing decision elsewhere; confirm before removing.
	segmentThreshold float64
	segmentThresholdFactor float64
	globalTSOAllocator func() (Timestamp, error) // issues cluster-wide timestamps
}
// ProcessQueryNodeStats applies every per-segment statistic carried by
// the query-node stats messages in msgPack to the meta table. It returns
// an error if any message is not a QueryNodeStatsMsg or if persisting a
// segment stat fails.
func (processor *StatsProcessor) ProcessQueryNodeStats(msgPack *msgstream.MsgPack) error {
	for _, msg := range msgPack.Msgs {
		statsMsg, ok := msg.(*msgstream.QueryNodeStatsMsg)
		if !ok {
			// Fixed message: the asserted type is QueryNodeStatsMsg; the
			// old text named a non-existent "QueryNodeSegStatsMsg".
			return errors.Errorf("Type of message is not QueryNodeStatsMsg")
		}
		for _, segStat := range statsMsg.GetSegStats() {
			if err := processor.processSegmentStat(segStat); err != nil {
				return err
			}
		}
	}
	return nil
}
// processSegmentStat persists a single segment statistic: when the
// segment was recently modified, its reported NumRows and MemSize are
// copied into the stored segment meta and written back.
func (processor *StatsProcessor) processSegmentStat(segStats *internalpb.SegmentStats) error {
	if !segStats.GetRecentlyModified() {
		return nil // nothing changed since the last report; skip the write
	}
	meta, err := processor.metaTable.GetSegmentByID(segStats.GetSegmentID())
	if err != nil {
		return err
	}
	meta.NumRows = segStats.NumRows
	meta.MemSize = segStats.MemorySize
	return processor.metaTable.UpdateSegment(meta)
}
// NewStatsProcessor builds a StatsProcessor bound to the given meta table
// and TSO allocator. Thresholds come from the global Params; SegmentSize
// is configured in MB, hence the *1024*1024 conversion to bytes.
func NewStatsProcessor(mt *metaTable, globalTSOAllocator func() (Timestamp, error)) *StatsProcessor {
	return &StatsProcessor{
		metaTable: mt,
		segmentThreshold: Params.SegmentSize * 1024 * 1024,
		segmentThresholdFactor: Params.SegmentSizeFactor,
		globalTSOAllocator: globalTSOAllocator,
	}
}

View File

@ -21,6 +21,9 @@ type timeSyncMsgProducer struct {
ctx context.Context
cancel context.CancelFunc
proxyWatchers []chan *ms.TimeTickMsg
writeNodeWatchers []chan *ms.TimeTickMsg
}
func NewTimeSyncMsgProducer(ctx context.Context) (*timeSyncMsgProducer, error) {
@ -47,7 +50,15 @@ func (syncMsgProducer *timeSyncMsgProducer) SetK2sSyncStream(k2sSync ms.MsgStrea
syncMsgProducer.k2sSyncStream = k2sSync
}
func (syncMsgProducer *timeSyncMsgProducer) broadcastMsg(barrier TimeTickBarrier, streams []ms.MsgStream) error {
// WatchProxyTtBarrier registers a channel that will receive every
// time-tick message broadcast from the proxy tt barrier.
func (syncMsgProducer *timeSyncMsgProducer) WatchProxyTtBarrier(watcher chan *ms.TimeTickMsg) {
	syncMsgProducer.proxyWatchers = append(syncMsgProducer.proxyWatchers, watcher)
}
// WatchWriteNodeTtBarrier registers a channel that will receive every
// time-tick message broadcast from the write-node tt barrier.
func (syncMsgProducer *timeSyncMsgProducer) WatchWriteNodeTtBarrier(watcher chan *ms.TimeTickMsg) {
	syncMsgProducer.writeNodeWatchers = append(syncMsgProducer.writeNodeWatchers, watcher)
}
func (syncMsgProducer *timeSyncMsgProducer) broadcastMsg(barrier TimeTickBarrier, streams []ms.MsgStream, channels []chan *ms.TimeTickMsg) error {
for {
select {
case <-syncMsgProducer.ctx.Done():
@ -79,6 +90,10 @@ func (syncMsgProducer *timeSyncMsgProducer) broadcastMsg(barrier TimeTickBarrier
for _, stream := range streams {
err = stream.Broadcast(&msgPack)
}
for _, channel := range channels {
channel <- timeTickMsg
}
if err != nil {
return err
}
@ -97,8 +112,8 @@ func (syncMsgProducer *timeSyncMsgProducer) Start() error {
return err
}
go syncMsgProducer.broadcastMsg(syncMsgProducer.proxyTtBarrier, []ms.MsgStream{syncMsgProducer.dmSyncStream, syncMsgProducer.ddSyncStream})
go syncMsgProducer.broadcastMsg(syncMsgProducer.writeNodeTtBarrier, []ms.MsgStream{syncMsgProducer.k2sSyncStream})
go syncMsgProducer.broadcastMsg(syncMsgProducer.proxyTtBarrier, []ms.MsgStream{syncMsgProducer.dmSyncStream, syncMsgProducer.ddSyncStream}, syncMsgProducer.proxyWatchers)
go syncMsgProducer.broadcastMsg(syncMsgProducer.writeNodeTtBarrier, []ms.MsgStream{syncMsgProducer.k2sSyncStream}, syncMsgProducer.writeNodeWatchers)
return nil
}

View File

@ -0,0 +1,36 @@
package typeutil
import (
"strconv"
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
// EstimateSizePerRecord estimates the in-memory size, in bytes, of one
// record following the given collection schema by summing the per-field
// sizes. String fields use a fixed guess since their real length is
// unknown at schema time. Returns an error only when a vector field's
// "dim" type param is not an integer.
func EstimateSizePerRecord(schema *schemapb.CollectionSchema) (int, error) {
	res := 0
	for _, fs := range schema.Fields {
		switch fs.DataType {
		case schemapb.DataType_BOOL, schemapb.DataType_INT8:
			res++
		case schemapb.DataType_INT16:
			res += 2
		case schemapb.DataType_INT32, schemapb.DataType_FLOAT:
			res += 4
		case schemapb.DataType_INT64, schemapb.DataType_DOUBLE:
			res += 8
		case schemapb.DataType_STRING:
			res += 125 // todo find a better way to estimate string type
		case schemapb.DataType_VECTOR_BINARY:
			for _, kv := range fs.TypeParams {
				if kv.Key == "dim" {
					v, err := strconv.Atoi(kv.Value)
					if err != nil {
						return -1, err
					}
					// One bit per dimension, rounded up to whole bytes.
					// The original added v bytes — an 8x overestimate.
					res += (v + 7) / 8
				}
			}
		case schemapb.DataType_VECTOR_FLOAT:
			for _, kv := range fs.TypeParams {
				if kv.Key == "dim" {
					v, err := strconv.Atoi(kv.Value)
					if err != nil {
						return -1, err
					}
					// float32 occupies 4 bytes per dimension. The original
					// added only v bytes — a 4x underestimate that inflated
					// how many rows the segment manager thought would fit.
					res += v * 4
				}
			}
		}
	}
	return res, nil
}