diff --git a/internal/indexcoord/index_builder_test.go b/internal/indexcoord/index_builder_test.go index 0c4b033902..81008a3242 100644 --- a/internal/indexcoord/index_builder_test.go +++ b/internal/indexcoord/index_builder_test.go @@ -166,6 +166,7 @@ func TestIndexBuilder(t *testing.T) { Err: false, }, nodeManager: &NodeManager{ + ctx: ctx, nodeClients: map[UniqueID]types.IndexNode{ 4: &indexnode.Mock{ Err: false, @@ -295,7 +296,9 @@ func TestIndexBuilder_Error(t *testing.T) { Fail: false, Err: false, }, - nodeManager: &NodeManager{}, + nodeManager: &NodeManager{ + ctx: ctx, + }, } mt := &metaTable{ indexBuildID2Meta: map[UniqueID]*Meta{ @@ -340,6 +343,7 @@ func TestIndexBuilder_Error(t *testing.T) { Err: false, }, nodeManager: &NodeManager{ + ctx: ctx, nodeClients: map[UniqueID]types.IndexNode{ 1: &indexnode.Mock{ Err: false, @@ -391,6 +395,7 @@ func TestIndexBuilder_Error(t *testing.T) { Err: true, }, nodeManager: &NodeManager{ + ctx: ctx, nodeClients: map[UniqueID]types.IndexNode{ 1: &indexnode.Mock{ Err: false, diff --git a/internal/indexcoord/node_manager.go b/internal/indexcoord/node_manager.go index 75ec90fa38..8fb71d2013 100644 --- a/internal/indexcoord/node_manager.go +++ b/internal/indexcoord/node_manager.go @@ -20,17 +20,14 @@ import ( "context" "sync" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/indexpb" - - "github.com/milvus-io/milvus/internal/metrics" - - "go.uber.org/zap" - grpcindexnodeclient "github.com/milvus-io/milvus/internal/distributed/indexnode/client" "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/metrics" + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/types" + "go.uber.org/zap" ) // NodeManager is used by IndexCoord to manage the client of IndexNode. @@ -56,7 +53,7 @@ func NewNodeManager(ctx context.Context) *NodeManager { // setClient sets IndexNode client to node manager. func (nm *NodeManager) setClient(nodeID UniqueID, client types.IndexNode) error { log.Debug("IndexCoord NodeManager setClient", zap.Int64("nodeID", nodeID)) - defer log.Debug("IndexNode NodeManager setclient success", zap.Any("nodeID", nodeID)) + defer log.Debug("IndexNode NodeManager setClient success", zap.Any("nodeID", nodeID)) item := &PQItem{ key: nodeID, priority: 0, @@ -105,32 +102,71 @@ func (nm *NodeManager) AddNode(nodeID UniqueID, address string) error { // PeekClient peeks the client with the least load. func (nm *NodeManager) PeekClient(meta *Meta) (UniqueID, types.IndexNode) { - nm.lock.RLock() - defer nm.lock.RUnlock() + allClients := nm.GetAllClients() - if len(nm.nodeClients) == 0 { + if len(allClients) == 0 { log.Error("there is no IndexNode online") return -1, nil } - for nodeID, client := range nm.nodeClients { - resp, err := client.GetTaskSlots(nm.ctx, &indexpb.GetTaskSlotsRequest{}) - if err != nil { - log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err)) - continue - } - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), - zap.String("reason", resp.Status.Reason)) - continue - } - if resp.Slots > 0 { - return nodeID, client - } + + // Note: In order to quickly end other goroutines, an error is returned when the client is successfully selected + ctx, cancel := context.WithCancel(nm.ctx) + var ( + peekNodeID = UniqueID(0) + nodeMutex = sync.Mutex{} + wg = sync.WaitGroup{} + ) + + for nodeID, client := range allClients { + nodeID := nodeID + client := client + wg.Add(1) + go func() { + defer wg.Done() + resp, err := client.GetTaskSlots(ctx, &indexpb.GetTaskSlotsRequest{}) + if err != nil { + log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), zap.Error(err)) + return + } + if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + log.Warn("get IndexNode slots failed", zap.Int64("nodeID", nodeID), + zap.String("reason", resp.Status.Reason)) + return + } + if resp.Slots > 0 { + nodeMutex.Lock() + defer nodeMutex.Unlock() + log.Info("peek client success", zap.Int64("nodeID", nodeID)) + if peekNodeID == 0 { + peekNodeID = nodeID + } + cancel() + // Note: In order to quickly end other goroutines, an error is returned when the client is successfully selected + return + } + }() + } + wg.Wait() + cancel() + if peekNodeID != 0 { + return peekNodeID, allClients[peekNodeID] } return 0, nil } +func (nm *NodeManager) GetAllClients() map[UniqueID]types.IndexNode { + nm.lock.RLock() + defer nm.lock.RUnlock() + + allClients := make(map[UniqueID]types.IndexNode, len(nm.nodeClients)) + for nodeID, client := range nm.nodeClients { + allClients[nodeID] = client + } + + return allClients +} + // indexNodeGetMetricsResponse record the metrics information of IndexNode. type indexNodeGetMetricsResponse struct { resp *milvuspb.GetMetricsResponse diff --git a/internal/indexcoord/node_manager_test.go b/internal/indexcoord/node_manager_test.go index bfbf3343f4..db70fae5cd 100644 --- a/internal/indexcoord/node_manager_test.go +++ b/internal/indexcoord/node_manager_test.go @@ -18,40 +18,145 @@ package indexcoord import ( "context" + "errors" "testing" + "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/schemapb" + "github.com/milvus-io/milvus/internal/types" "github.com/stretchr/testify/assert" ) func TestNodeManager_PeekClient(t *testing.T) { - nm := NewNodeManager(context.Background()) - meta := &Meta{ - indexMeta: &indexpb.IndexMeta{ - Req: &indexpb.BuildIndexRequest{ - DataPaths: []string{"PeekClient-1", "PeekClient-2"}, - NumRows: 1000, - TypeParams: []*commonpb.KeyValuePair{ - { - Key: "dim", - Value: "128", + t.Run("success", func(t *testing.T) { + nm := NewNodeManager(context.Background()) + meta := &Meta{ + indexMeta: &indexpb.IndexMeta{ + Req: &indexpb.BuildIndexRequest{ + DataPaths: []string{"PeekClient-1", "PeekClient-2"}, + NumRows: 1000, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: "dim", + Value: "128", + }, + }, + FieldSchema: &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, }, }, - FieldSchema: &schemapb.FieldSchema{ - DataType: schemapb.DataType_FloatVector, + }, + } + nodeID, client := nm.PeekClient(meta) + assert.Equal(t, int64(-1), nodeID) + assert.Nil(t, client) + err := nm.AddNode(1, "indexnode-1") + assert.Nil(t, err) + nm.pq.SetMemory(1, 100) + nodeID2, client2 := nm.PeekClient(meta) + assert.Equal(t, int64(0), nodeID2) + assert.Nil(t, client2) + }) + + t.Run("multiple unavailable IndexNode", func(t *testing.T) { + nm := &NodeManager{ + ctx: context.TODO(), + nodeClients: map[UniqueID]types.IndexNode{ + 1: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, + }, errors.New("error") + }, + }, + 2: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, + }, errors.New("error") + }, + }, + 3: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, + }, errors.New("error") + }, + }, + 4: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, + }, errors.New("error") + }, + }, + 5: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "fail reason", + }, + }, nil + }, + }, + 6: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "fail reason", + }, + }, nil + }, + }, + 7: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "fail reason", + }, + }, nil + }, + }, + 8: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Slots: 1, + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + }, nil + }, + }, + 9: &indexnode.MockIndexNode{ + GetTaskSlotsMock: func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return &indexpb.GetTaskSlotsResponse{ + Slots: 10, + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + }, nil + }, }, }, - }, - } - nodeID, client := nm.PeekClient(meta) - assert.Equal(t, int64(-1), nodeID) - assert.Nil(t, client) - err := nm.AddNode(1, "indexnode-1") - assert.Nil(t, err) - nm.pq.SetMemory(1, 100) - nodeID2, client2 := nm.PeekClient(meta) - assert.Equal(t, int64(0), nodeID2) - assert.Nil(t, client2) + } + + nodeID, client := nm.PeekClient(&Meta{}) + assert.NotNil(t, client) + assert.Contains(t, []UniqueID{8, 9}, nodeID) + }) } diff --git a/internal/indexnode/indexnode_mock.go b/internal/indexnode/indexnode_mock.go index d57381f3c9..289a067168 100644 --- a/internal/indexnode/indexnode_mock.go +++ b/internal/indexnode/indexnode_mock.go @@ -22,6 +22,8 @@ import ( "fmt" "sync" + "github.com/milvus-io/milvus/internal/types" + "go.uber.org/zap" "github.com/golang/protobuf/proto" @@ -39,6 +41,7 @@ import ( ) // Mock is an alternative to IndexNode, it will return specific results based on specific parameters. +// Deprecated, use MockIndexNode type Mock struct { Build bool Failure bool @@ -386,3 +389,23 @@ func getMockSystemInfoMetrics( ComponentName: metricsinfo.ConstructComponentName(typeutil.IndexNodeRole, Params.IndexNodeCfg.GetNodeID()), }, nil } + +type MockIndexNode struct { + types.IndexNode + + CreateIndexMock func(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) + GetTaskSlotsMock func(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) + GetMetricsMock func(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) +} + +func (min *MockIndexNode) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { + return min.CreateIndexMock(ctx, req) +} + +func (min *MockIndexNode) GetTaskSlots(ctx context.Context, req *indexpb.GetTaskSlotsRequest) (*indexpb.GetTaskSlotsResponse, error) { + return min.GetTaskSlotsMock(ctx, req) +} + +func (min *MockIndexNode) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + return min.GetMetricsMock(ctx, req) +}