Refactor task interface of Proxy and fix the wait logic of task

Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
dragondriver 2020-11-11 16:00:56 +08:00 committed by yefu.chen
parent c442d50c08
commit fb1e24ade8
5 changed files with 159 additions and 151 deletions

View File

@ -35,21 +35,18 @@ func (p *Proxy) Insert(ctx context.Context, in *servicepb.RowBatch) (*servicepb.
defer it.cancel()
var t task = it
p.taskSch.DmQueue.Enqueue(&t)
for {
select {
case <-ctx.Done():
log.Print("insert timeout!")
return &servicepb.IntegerRangeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "insert timeout!",
},
}, errors.New("insert timeout!")
case result := <-it.resultChan:
return result, nil
}
p.taskSch.DmQueue.Enqueue(it)
select {
case <-ctx.Done():
log.Print("insert timeout!")
return &servicepb.IntegerRangeResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "insert timeout!",
},
}, errors.New("insert timeout!")
case result := <-it.resultChan:
return result, nil
}
}
@ -69,19 +66,16 @@ func (p *Proxy) CreateCollection(ctx context.Context, req *schemapb.CollectionSc
cct.ctx, cct.cancel = context.WithCancel(ctx)
defer cct.cancel()
var t task = cct
p.taskSch.DdQueue.Enqueue(&t)
for {
select {
case <-ctx.Done():
log.Print("create collection timeout!")
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "create collection timeout!",
}, errors.New("create collection timeout!")
case result := <-cct.resultChan:
return result, nil
}
p.taskSch.DdQueue.Enqueue(cct)
select {
case <-ctx.Done():
log.Print("create collection timeout!")
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "create collection timeout!",
}, errors.New("create collection timeout!")
case result := <-cct.resultChan:
return result, nil
}
}
@ -102,21 +96,18 @@ func (p *Proxy) Search(ctx context.Context, req *servicepb.Query) (*servicepb.Qu
qt.SearchRequest.Query.Value = queryBytes
defer qt.cancel()
var t task = qt
p.taskSch.DqQueue.Enqueue(&t)
for {
select {
case <-ctx.Done():
log.Print("query timeout!")
return &servicepb.QueryResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "query timeout!",
},
}, errors.New("query timeout!")
case result := <-qt.resultChan:
return result, nil
}
p.taskSch.DqQueue.Enqueue(qt)
select {
case <-ctx.Done():
log.Print("query timeout!")
return &servicepb.QueryResult{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "query timeout!",
},
}, errors.New("query timeout!")
case result := <-qt.resultChan:
return result, nil
}
}
@ -134,19 +125,16 @@ func (p *Proxy) DropCollection(ctx context.Context, req *servicepb.CollectionNam
dct.ctx, dct.cancel = context.WithCancel(ctx)
defer dct.cancel()
var t task = dct
p.taskSch.DdQueue.Enqueue(&t)
for {
select {
case <-ctx.Done():
log.Print("create collection timeout!")
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "create collection timeout!",
}, errors.New("create collection timeout!")
case result := <-dct.resultChan:
return result, nil
}
p.taskSch.DdQueue.Enqueue(dct)
select {
case <-ctx.Done():
log.Print("create collection timeout!")
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "create collection timeout!",
}, errors.New("create collection timeout!")
case result := <-dct.resultChan:
return result, nil
}
}
@ -164,22 +152,19 @@ func (p *Proxy) HasCollection(ctx context.Context, req *servicepb.CollectionName
hct.ctx, hct.cancel = context.WithCancel(ctx)
defer hct.cancel()
var t task = hct
p.taskSch.DqQueue.Enqueue(&t)
for {
select {
case <-ctx.Done():
log.Print("has collection timeout!")
return &servicepb.BoolResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "has collection timeout!",
},
Value: false,
}, errors.New("has collection timeout!")
case result := <-hct.resultChan:
return result, nil
}
p.taskSch.DqQueue.Enqueue(hct)
select {
case <-ctx.Done():
log.Print("has collection timeout!")
return &servicepb.BoolResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "has collection timeout!",
},
Value: false,
}, errors.New("has collection timeout!")
case result := <-hct.resultChan:
return result, nil
}
}
@ -197,21 +182,18 @@ func (p *Proxy) DescribeCollection(ctx context.Context, req *servicepb.Collectio
dct.ctx, dct.cancel = context.WithCancel(ctx)
defer dct.cancel()
var t task = dct
p.taskSch.DqQueue.Enqueue(&t)
for {
select {
case <-ctx.Done():
log.Print("has collection timeout!")
return &servicepb.CollectionDescription{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "describe collection timeout!",
},
}, errors.New("describe collection timeout!")
case result := <-dct.resultChan:
return result, nil
}
p.taskSch.DqQueue.Enqueue(dct)
select {
case <-ctx.Done():
log.Print("has collection timeout!")
return &servicepb.CollectionDescription{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "describe collection timeout!",
},
}, errors.New("describe collection timeout!")
case result := <-dct.resultChan:
return result, nil
}
}
@ -228,21 +210,18 @@ func (p *Proxy) ShowCollections(ctx context.Context, req *commonpb.Empty) (*serv
sct.ctx, sct.cancel = context.WithCancel(ctx)
defer sct.cancel()
var t task = sct
p.taskSch.DqQueue.Enqueue(&t)
for {
select {
case <-ctx.Done():
log.Print("show collections timeout!")
return &servicepb.StringListResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "show collections timeout!",
},
}, errors.New("show collections timeout!")
case result := <-sct.resultChan:
return result, nil
}
p.taskSch.DqQueue.Enqueue(sct)
select {
case <-ctx.Done():
log.Print("show collections timeout!")
return &servicepb.StringListResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UNEXPECTED_ERROR,
Reason: "show collections timeout!",
},
}, errors.New("show collections timeout!")
case result := <-sct.resultChan:
return result, nil
}
}

View File

@ -2,6 +2,7 @@ package proxy
import (
"context"
"google.golang.org/grpc"
"log"
"math/rand"
"net"
@ -14,7 +15,6 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/masterpb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
"google.golang.org/grpc"
)
type UniqueID = typeutil.UniqueID
@ -157,7 +157,7 @@ func (p *Proxy) queryResultLoop() {
if len(queryResultBuf[reqId]) == 4 {
// TODO: use the number of query node instead
t := p.taskSch.getTaskByReqId(reqId)
qt := (*t).(*QueryTask)
qt := t.(*QueryTask)
qt.resultBuf <- queryResultBuf[reqId]
delete(queryResultBuf, reqId)
}

View File

@ -11,7 +11,7 @@ import (
type BaseTaskQueue struct {
unissuedTasks *list.List
activeTasks map[Timestamp]*task
activeTasks map[Timestamp]task
utLock sync.Mutex
atLock sync.Mutex
}
@ -24,23 +24,23 @@ func (queue *BaseTaskQueue) Empty() bool {
return queue.unissuedTasks.Len() <= 0 && len(queue.activeTasks) <= 0
}
func (queue *BaseTaskQueue) AddUnissuedTask(t *task) {
func (queue *BaseTaskQueue) AddUnissuedTask(t task) {
queue.utLock.Lock()
defer queue.utLock.Unlock()
queue.unissuedTasks.PushBack(t)
}
func (queue *BaseTaskQueue) FrontUnissuedTask() *task {
func (queue *BaseTaskQueue) FrontUnissuedTask() task {
queue.utLock.Lock()
defer queue.utLock.Unlock()
if queue.unissuedTasks.Len() <= 0 {
log.Fatal("sorry, but the unissued task list is empty!")
return nil
}
return queue.unissuedTasks.Front().Value.(*task)
return queue.unissuedTasks.Front().Value.(task)
}
func (queue *BaseTaskQueue) PopUnissuedTask() *task {
func (queue *BaseTaskQueue) PopUnissuedTask() task {
queue.utLock.Lock()
defer queue.utLock.Unlock()
if queue.unissuedTasks.Len() <= 0 {
@ -48,13 +48,13 @@ func (queue *BaseTaskQueue) PopUnissuedTask() *task {
return nil
}
ft := queue.unissuedTasks.Front()
return queue.unissuedTasks.Remove(ft).(*task)
return queue.unissuedTasks.Remove(ft).(task)
}
func (queue *BaseTaskQueue) AddActiveTask(t *task) {
func (queue *BaseTaskQueue) AddActiveTask(t task) {
queue.atLock.Lock()
defer queue.atLock.Lock()
ts := (*t).EndTs()
ts := t.EndTs()
_, ok := queue.activeTasks[ts]
if ok {
log.Fatalf("task with timestamp %v already in active task list!", ts)
@ -62,7 +62,7 @@ func (queue *BaseTaskQueue) AddActiveTask(t *task) {
queue.activeTasks[ts] = t
}
func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) *task {
func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) task {
queue.atLock.Lock()
defer queue.atLock.Lock()
t, ok := queue.activeTasks[ts]
@ -74,19 +74,19 @@ func (queue *BaseTaskQueue) PopActiveTask(ts Timestamp) *task {
return nil
}
func (queue *BaseTaskQueue) getTaskByReqId(reqId UniqueID) *task {
func (queue *BaseTaskQueue) getTaskByReqId(reqId UniqueID) task {
queue.utLock.Lock()
defer queue.utLock.Lock()
for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
if (*(e.Value.(*task))).Id() == reqId {
return e.Value.(*task)
if e.Value.(task).Id() == reqId {
return e.Value.(task)
}
}
queue.atLock.Lock()
defer queue.atLock.Unlock()
for ats := range queue.activeTasks {
if (*(queue.activeTasks[ats])).Id() == reqId {
if queue.activeTasks[ats].Id() == reqId {
return queue.activeTasks[ats]
}
}
@ -98,7 +98,7 @@ func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool {
queue.utLock.Lock()
defer queue.utLock.Unlock()
for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
if (*(e.Value.(*task))).EndTs() >= ts {
if e.Value.(task).EndTs() >= ts {
return false
}
}
@ -114,20 +114,20 @@ func (queue *BaseTaskQueue) TaskDoneTest(ts Timestamp) bool {
return true
}
type ddTaskQueue struct {
type DdTaskQueue struct {
BaseTaskQueue
lock sync.Mutex
}
type dmTaskQueue struct {
type DmTaskQueue struct {
BaseTaskQueue
}
type dqTaskQueue struct {
type DqTaskQueue struct {
BaseTaskQueue
}
func (queue *ddTaskQueue) Enqueue(t *task) error {
func (queue *DdTaskQueue) Enqueue(t task) error {
queue.lock.Lock()
defer queue.lock.Unlock()
// TODO: set Ts, ReqId, ProxyId
@ -135,22 +135,49 @@ func (queue *ddTaskQueue) Enqueue(t *task) error {
return nil
}
func (queue *dmTaskQueue) Enqueue(t *task) error {
func (queue *DmTaskQueue) Enqueue(t task) error {
// TODO: set Ts, ReqId, ProxyId
queue.AddUnissuedTask(t)
return nil
}
func (queue *dqTaskQueue) Enqueue(t *task) error {
func (queue *DqTaskQueue) Enqueue(t task) error {
// TODO: set Ts, ReqId, ProxyId
queue.AddUnissuedTask(t)
return nil
}
func NewDdTaskQueue() *DdTaskQueue {
return &DdTaskQueue{
BaseTaskQueue: BaseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[Timestamp]task),
},
}
}
func NewDmTaskQueue() *DmTaskQueue {
return &DmTaskQueue{
BaseTaskQueue: BaseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[Timestamp]task),
},
}
}
func NewDqTaskQueue() *DqTaskQueue {
return &DqTaskQueue{
BaseTaskQueue: BaseTaskQueue{
unissuedTasks: list.New(),
activeTasks: make(map[Timestamp]task),
},
}
}
type TaskScheduler struct {
DdQueue *ddTaskQueue
DmQueue *dmTaskQueue
DqQueue *dqTaskQueue
DdQueue *DdTaskQueue
DmQueue *DmTaskQueue
DqQueue *DqTaskQueue
idAllocator *allocator.IdAllocator
tsoAllocator *allocator.TimestampAllocator
@ -165,6 +192,9 @@ func NewTaskScheduler(ctx context.Context,
tsoAllocator *allocator.TimestampAllocator) (*TaskScheduler, error) {
ctx1, cancel := context.WithCancel(ctx)
s := &TaskScheduler{
DdQueue: NewDdTaskQueue(),
DmQueue: NewDmTaskQueue(),
DqQueue: NewDqTaskQueue(),
idAllocator: idAllocator,
tsoAllocator: tsoAllocator,
ctx: ctx1,
@ -174,19 +204,19 @@ func NewTaskScheduler(ctx context.Context,
return s, nil
}
func (sched *TaskScheduler) scheduleDdTask() *task {
func (sched *TaskScheduler) scheduleDdTask() task {
return sched.DdQueue.PopUnissuedTask()
}
func (sched *TaskScheduler) scheduleDmTask() *task {
func (sched *TaskScheduler) scheduleDmTask() task {
return sched.DmQueue.PopUnissuedTask()
}
func (sched *TaskScheduler) scheduleDqTask() *task {
func (sched *TaskScheduler) scheduleDqTask() task {
return sched.DqQueue.PopUnissuedTask()
}
func (sched *TaskScheduler) getTaskByReqId(reqId UniqueID) *task {
func (sched *TaskScheduler) getTaskByReqId(reqId UniqueID) task {
if t := sched.DdQueue.getTaskByReqId(reqId); t != nil {
return t
}
@ -211,22 +241,22 @@ func (sched *TaskScheduler) definitionLoop() {
//sched.DdQueue.atLock.Lock()
t := sched.scheduleDdTask()
err := (*t).PreExecute()
err := t.PreExecute()
if err != nil {
return
}
err = (*t).Execute()
err = t.Execute()
if err != nil {
log.Printf("execute definition task failed, error = %v", err)
}
(*t).Notify(err)
t.Notify(err)
sched.DdQueue.AddActiveTask(t)
(*t).WaitToFinish()
(*t).PostExecute()
t.WaitToFinish()
t.PostExecute()
sched.DdQueue.PopActiveTask((*t).EndTs())
sched.DdQueue.PopActiveTask(t.EndTs())
}
}
@ -242,27 +272,27 @@ func (sched *TaskScheduler) manipulationLoop() {
sched.DmQueue.atLock.Lock()
t := sched.scheduleDmTask()
if err := (*t).PreExecute(); err != nil {
if err := t.PreExecute(); err != nil {
return
}
go func() {
err := (*t).Execute()
err := t.Execute()
if err != nil {
log.Printf("execute manipulation task failed, error = %v", err)
}
(*t).Notify(err)
t.Notify(err)
}()
sched.DmQueue.AddActiveTask(t)
sched.DmQueue.atLock.Unlock()
go func() {
(*t).WaitToFinish()
(*t).PostExecute()
t.WaitToFinish()
t.PostExecute()
// remove from active list
sched.DmQueue.PopActiveTask((*t).EndTs())
sched.DmQueue.PopActiveTask(t.EndTs())
}()
}
}
@ -279,27 +309,27 @@ func (sched *TaskScheduler) queryLoop() {
sched.DqQueue.atLock.Lock()
t := sched.scheduleDqTask()
if err := (*t).PreExecute(); err != nil {
if err := t.PreExecute(); err != nil {
return
}
go func() {
err := (*t).Execute()
err := t.Execute()
if err != nil {
log.Printf("execute query task failed, error = %v", err)
}
(*t).Notify(err)
t.Notify(err)
}()
sched.DqQueue.AddActiveTask(t)
sched.DqQueue.atLock.Unlock()
go func() {
(*t).WaitToFinish()
(*t).PostExecute()
t.WaitToFinish()
t.PostExecute()
// remove from active list
sched.DqQueue.PopActiveTask((*t).EndTs())
sched.DqQueue.PopActiveTask(t.EndTs())
}()
}
}

View File

@ -51,7 +51,6 @@ func newTimeTick(ctx context.Context, tsoAllocator *allocator.TimestampAllocator
return t
}
func (tt *timeTick) tick() error {
if tt.lastTick == tt.currentTick {

View File

@ -33,7 +33,7 @@ func TestTimeTick(t *testing.T) {
tt := timeTick{
interval: 200,
pulsarProducer: producer,
peerID: 1,
peerID: 1,
ctx: ctx,
areRequestsDelivered: func(ts Timestamp) bool { return true },
}