mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-30 15:35:33 +08:00
**What type of PR is this?**
- [x] Feature
**What this PR does / why we need it:**
This PR adds support for boolean expressions as a DSL.
1. The goal of this PR is to support predicates
like `A > 3 && not B < 5 or C in [1, 2, 3]`.
2. Defines `plan.proto` as the Intermediate Representation (IR)
used between Go and C++.
3. Adds an expression parser in proxynode that converts predicate expressions
to IR, performing static checks along the way.
4. Supports converting the IR to an AST in C++, enabling execution.
2668 lines
66 KiB
Go
2668 lines
66 KiB
Go
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software distributed under the License
|
|
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
|
// or implied. See the License for the specific language governing permissions and limitations under the License.
|
|
|
|
package proxynode
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"regexp"
|
|
"runtime"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/milvus-io/milvus/internal/proto/planpb"
|
|
|
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/golang/protobuf/proto"
|
|
"github.com/milvus-io/milvus/internal/allocator"
|
|
"github.com/milvus-io/milvus/internal/log"
|
|
"github.com/milvus-io/milvus/internal/msgstream"
|
|
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
|
"github.com/milvus-io/milvus/internal/proto/datapb"
|
|
"github.com/milvus-io/milvus/internal/proto/indexpb"
|
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
|
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
|
"github.com/milvus-io/milvus/internal/types"
|
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
|
)
|
|
|
|
// Task names returned by the Name() method of the corresponding task types.
const (
	InsertTaskName                  = "InsertTask"
	CreateCollectionTaskName        = "CreateCollectionTask"
	DropCollectionTaskName          = "DropCollectionTask"
	SearchTaskName                  = "SearchTask"
	HasCollectionTaskName           = "HasCollectionTask"
	DescribeCollectionTaskName      = "DescribeCollectionTask"
	GetCollectionStatisticsTaskName = "GetCollectionStatisticsTask"
	ShowCollectionTaskName          = "ShowCollectionTask"
	CreatePartitionTaskName         = "CreatePartitionTask"
	DropPartitionTaskName           = "DropPartitionTask"
	HasPartitionTaskName            = "HasPartitionTask"
	ShowPartitionTaskName           = "ShowPartitionTask"
	CreateIndexTaskName             = "CreateIndexTask"
	DescribeIndexTaskName           = "DescribeIndexTask"
	DropIndexTaskName               = "DropIndexTask"
	GetIndexStateTaskName           = "GetIndexStateTask"
	GetIndexBuildProgressTaskName   = "GetIndexBuildProgressTask"
	FlushTaskName                   = "FlushTask"
	LoadCollectionTaskName          = "LoadCollectionTask"
	ReleaseCollectionTaskName       = "ReleaseCollectionTask"
	LoadPartitionTaskName           = "LoadPartitionTask"
	ReleasePartitionTaskName        = "ReleasePartitionTask"
)

// Keys looked up in a search request's SearchParams key/value list when the
// DSL is a boolean expression (DslType_BoolExprV1).
const (
	AnnsFieldKey    = "anns_field"
	TopKKey         = "topk"
	MetricTypeKey   = "metric_type"
	SearchParamsKey = "params"
)
|
|
|
|
type task interface {
|
|
TraceCtx() context.Context
|
|
ID() UniqueID // return ReqID
|
|
SetID(uid UniqueID) // set ReqID
|
|
Name() string
|
|
Type() commonpb.MsgType
|
|
BeginTs() Timestamp
|
|
EndTs() Timestamp
|
|
SetTs(ts Timestamp)
|
|
OnEnqueue() error
|
|
PreExecute(ctx context.Context) error
|
|
Execute(ctx context.Context) error
|
|
PostExecute(ctx context.Context) error
|
|
WaitToFinish() error
|
|
Notify(err error)
|
|
}
|
|
|
|
type BaseInsertTask = msgstream.InsertMsg
|
|
|
|
type InsertTask struct {
|
|
BaseInsertTask
|
|
Condition
|
|
ctx context.Context
|
|
dataService types.DataService
|
|
result *milvuspb.InsertResponse
|
|
rowIDAllocator *allocator.IDAllocator
|
|
}
|
|
|
|
func (it *InsertTask) TraceCtx() context.Context {
|
|
return it.ctx
|
|
}
|
|
|
|
func (it *InsertTask) ID() UniqueID {
|
|
return it.Base.MsgID
|
|
}
|
|
|
|
func (it *InsertTask) SetID(uid UniqueID) {
|
|
it.Base.MsgID = uid
|
|
}
|
|
|
|
func (it *InsertTask) Name() string {
|
|
return InsertTaskName
|
|
}
|
|
|
|
func (it *InsertTask) Type() commonpb.MsgType {
|
|
return it.Base.MsgType
|
|
}
|
|
|
|
func (it *InsertTask) BeginTs() Timestamp {
|
|
return it.BeginTimestamp
|
|
}
|
|
|
|
func (it *InsertTask) SetTs(ts Timestamp) {
|
|
rowNum := len(it.RowData)
|
|
it.Timestamps = make([]uint64, rowNum)
|
|
for index := range it.Timestamps {
|
|
it.Timestamps[index] = ts
|
|
}
|
|
it.BeginTimestamp = ts
|
|
it.EndTimestamp = ts
|
|
}
|
|
|
|
func (it *InsertTask) EndTs() Timestamp {
|
|
return it.EndTimestamp
|
|
}
|
|
|
|
func (it *InsertTask) OnEnqueue() error {
|
|
it.BaseInsertTask.InsertRequest.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (it *InsertTask) PreExecute(ctx context.Context) error {
|
|
it.Base.MsgType = commonpb.MsgType_Insert
|
|
it.Base.SourceID = Params.ProxyID
|
|
|
|
collectionName := it.BaseInsertTask.CollectionName
|
|
if err := ValidateCollectionName(collectionName); err != nil {
|
|
return err
|
|
}
|
|
partitionTag := it.BaseInsertTask.PartitionName
|
|
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (it *InsertTask) Execute(ctx context.Context) error {
|
|
collectionName := it.BaseInsertTask.CollectionName
|
|
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
autoID := collSchema.AutoID
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
it.CollectionID = collID
|
|
var partitionID UniqueID
|
|
if len(it.PartitionName) > 0 {
|
|
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.PartitionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, Params.DefaultPartitionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
it.PartitionID = partitionID
|
|
var rowIDBegin UniqueID
|
|
var rowIDEnd UniqueID
|
|
rowNums := len(it.BaseInsertTask.RowData)
|
|
rowIDBegin, rowIDEnd, _ = it.rowIDAllocator.Alloc(uint32(rowNums))
|
|
|
|
it.BaseInsertTask.RowIDs = make([]UniqueID, rowNums)
|
|
for i := rowIDBegin; i < rowIDEnd; i++ {
|
|
offset := i - rowIDBegin
|
|
it.BaseInsertTask.RowIDs[offset] = i
|
|
}
|
|
|
|
if autoID {
|
|
if it.HashValues == nil || len(it.HashValues) == 0 {
|
|
it.HashValues = make([]uint32, 0)
|
|
}
|
|
for _, rowID := range it.RowIDs {
|
|
hashValue, _ := typeutil.Hash32Int64(rowID)
|
|
it.HashValues = append(it.HashValues, hashValue)
|
|
}
|
|
}
|
|
|
|
var tsMsg msgstream.TsMsg = &it.BaseInsertTask
|
|
it.BaseMsg.Ctx = ctx
|
|
msgPack := msgstream.MsgPack{
|
|
BeginTs: it.BeginTs(),
|
|
EndTs: it.EndTs(),
|
|
Msgs: make([]msgstream.TsMsg, 1),
|
|
}
|
|
|
|
it.result = &milvuspb.InsertResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
},
|
|
RowIDBegin: rowIDBegin,
|
|
RowIDEnd: rowIDEnd,
|
|
}
|
|
|
|
msgPack.Msgs[0] = tsMsg
|
|
|
|
stream, err := globalInsertChannelsMap.GetInsertMsgStream(collID)
|
|
if err != nil {
|
|
resp, _ := it.dataService.GetInsertChannels(ctx, &datapb.GetInsertChannelsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Insert, // todo
|
|
MsgID: it.Base.MsgID, // todo
|
|
Timestamp: 0, // todo
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
DbID: 0, // todo
|
|
CollectionID: collID,
|
|
})
|
|
if resp == nil {
|
|
return errors.New("get insert channels resp is nil")
|
|
}
|
|
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(resp.Status.Reason)
|
|
}
|
|
err = globalInsertChannelsMap.CreateInsertMsgStream(collID, resp.Values)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
stream, err = globalInsertChannelsMap.GetInsertMsgStream(collID)
|
|
if err != nil {
|
|
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
|
it.result.Status.Reason = err.Error()
|
|
return err
|
|
}
|
|
|
|
err = stream.Produce(&msgPack)
|
|
if err != nil {
|
|
it.result.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
|
|
it.result.Status.Reason = err.Error()
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (it *InsertTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type CreateCollectionTask struct {
|
|
Condition
|
|
*milvuspb.CreateCollectionRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
dataServiceClient types.DataService
|
|
result *commonpb.Status
|
|
schema *schemapb.CollectionSchema
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) TraceCtx() context.Context {
|
|
return cct.ctx
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) ID() UniqueID {
|
|
return cct.Base.MsgID
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) SetID(uid UniqueID) {
|
|
cct.Base.MsgID = uid
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) Name() string {
|
|
return CreateCollectionTaskName
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) Type() commonpb.MsgType {
|
|
return cct.Base.MsgType
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) BeginTs() Timestamp {
|
|
return cct.Base.Timestamp
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) EndTs() Timestamp {
|
|
return cct.Base.Timestamp
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) SetTs(ts Timestamp) {
|
|
cct.Base.Timestamp = ts
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) OnEnqueue() error {
|
|
cct.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) PreExecute(ctx context.Context) error {
|
|
cct.Base.MsgType = commonpb.MsgType_CreateCollection
|
|
cct.Base.SourceID = Params.ProxyID
|
|
|
|
cct.schema = &schemapb.CollectionSchema{}
|
|
err := proto.Unmarshal(cct.Schema, cct.schema)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if int64(len(cct.schema.Fields)) > Params.MaxFieldNum {
|
|
return fmt.Errorf("maximum field's number should be limited to %d", Params.MaxFieldNum)
|
|
}
|
|
|
|
// validate collection name
|
|
if err := ValidateCollectionName(cct.schema.Name); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ValidateDuplicatedFieldName(cct.schema.Fields); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ValidatePrimaryKey(cct.schema); err != nil {
|
|
return err
|
|
}
|
|
|
|
// validate field name
|
|
for _, field := range cct.schema.Fields {
|
|
if err := ValidateFieldName(field.Name); err != nil {
|
|
return err
|
|
}
|
|
if field.DataType == schemapb.DataType_FloatVector || field.DataType == schemapb.DataType_BinaryVector {
|
|
exist := false
|
|
var dim int64 = 0
|
|
for _, param := range field.TypeParams {
|
|
if param.Key == "dim" {
|
|
exist = true
|
|
tmp, err := strconv.ParseInt(param.Value, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dim = tmp
|
|
break
|
|
}
|
|
}
|
|
if !exist {
|
|
return errors.New("dimension is not defined in field type params")
|
|
}
|
|
if field.DataType == schemapb.DataType_FloatVector {
|
|
if err := ValidateDimension(dim, false); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := ValidateDimension(dim, true); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
cct.result, err = cct.masterService.CreateCollection(ctx, cct.CreateCollectionRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if cct.result.ErrorCode == commonpb.ErrorCode_Success {
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, cct.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
resp, _ := cct.dataServiceClient.GetInsertChannels(ctx, &datapb.GetInsertChannelsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Insert, // todo
|
|
MsgID: cct.Base.MsgID, // todo
|
|
Timestamp: 0, // todo
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
DbID: 0, // todo
|
|
CollectionID: collID,
|
|
})
|
|
if resp == nil {
|
|
return errors.New("get insert channels resp is nil")
|
|
}
|
|
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(resp.Status.Reason)
|
|
}
|
|
err = globalInsertChannelsMap.CreateInsertMsgStream(collID, resp.Values)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (cct *CreateCollectionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type DropCollectionTask struct {
|
|
Condition
|
|
*milvuspb.DropCollectionRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (dct *DropCollectionTask) TraceCtx() context.Context {
|
|
return dct.ctx
|
|
}
|
|
|
|
func (dct *DropCollectionTask) ID() UniqueID {
|
|
return dct.Base.MsgID
|
|
}
|
|
|
|
func (dct *DropCollectionTask) SetID(uid UniqueID) {
|
|
dct.Base.MsgID = uid
|
|
}
|
|
|
|
func (dct *DropCollectionTask) Name() string {
|
|
return DropCollectionTaskName
|
|
}
|
|
|
|
func (dct *DropCollectionTask) Type() commonpb.MsgType {
|
|
return dct.Base.MsgType
|
|
}
|
|
|
|
func (dct *DropCollectionTask) BeginTs() Timestamp {
|
|
return dct.Base.Timestamp
|
|
}
|
|
|
|
func (dct *DropCollectionTask) EndTs() Timestamp {
|
|
return dct.Base.Timestamp
|
|
}
|
|
|
|
func (dct *DropCollectionTask) SetTs(ts Timestamp) {
|
|
dct.Base.Timestamp = ts
|
|
}
|
|
|
|
func (dct *DropCollectionTask) OnEnqueue() error {
|
|
dct.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (dct *DropCollectionTask) PreExecute(ctx context.Context) error {
|
|
dct.Base.MsgType = commonpb.MsgType_DropCollection
|
|
dct.Base.SourceID = Params.ProxyID
|
|
|
|
if err := ValidateCollectionName(dct.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (dct *DropCollectionTask) Execute(ctx context.Context) error {
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, dct.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
dct.result, err = dct.masterService.DropCollection(ctx, dct.DropCollectionRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = globalInsertChannelsMap.CloseInsertMsgStream(collID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (dct *DropCollectionTask) PostExecute(ctx context.Context) error {
|
|
globalMetaCache.RemoveCollection(ctx, dct.CollectionName)
|
|
return nil
|
|
}
|
|
|
|
type SearchTask struct {
|
|
Condition
|
|
*internalpb.SearchRequest
|
|
ctx context.Context
|
|
queryMsgStream msgstream.MsgStream
|
|
resultBuf chan []*internalpb.SearchResults
|
|
result *milvuspb.SearchResults
|
|
query *milvuspb.SearchRequest
|
|
}
|
|
|
|
func (st *SearchTask) TraceCtx() context.Context {
|
|
return st.ctx
|
|
}
|
|
|
|
func (st *SearchTask) ID() UniqueID {
|
|
return st.Base.MsgID
|
|
}
|
|
|
|
func (st *SearchTask) SetID(uid UniqueID) {
|
|
st.Base.MsgID = uid
|
|
}
|
|
|
|
func (st *SearchTask) Name() string {
|
|
return SearchTaskName
|
|
}
|
|
|
|
func (st *SearchTask) Type() commonpb.MsgType {
|
|
return st.Base.MsgType
|
|
}
|
|
|
|
func (st *SearchTask) BeginTs() Timestamp {
|
|
return st.Base.Timestamp
|
|
}
|
|
|
|
func (st *SearchTask) EndTs() Timestamp {
|
|
return st.Base.Timestamp
|
|
}
|
|
|
|
func (st *SearchTask) SetTs(ts Timestamp) {
|
|
st.Base.Timestamp = ts
|
|
}
|
|
|
|
func (st *SearchTask) OnEnqueue() error {
|
|
st.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (st *SearchTask) PreExecute(ctx context.Context) error {
|
|
st.Base.MsgType = commonpb.MsgType_Search
|
|
st.Base.SourceID = Params.ProxyID
|
|
|
|
collectionName := st.query.CollectionName
|
|
_, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
if err != nil { // err is not nil if collection not exists
|
|
return err
|
|
}
|
|
|
|
if err := ValidateCollectionName(st.query.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, tag := range st.query.PartitionNames {
|
|
if err := ValidatePartitionTag(tag, false); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
st.Base.MsgType = commonpb.MsgType_Search
|
|
|
|
var dsl string
|
|
dsl = st.query.Dsl
|
|
if st.query.GetDslType() == commonpb.DslType_BoolExprV1 {
|
|
schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
|
if err != nil { // err is not nil if collection not exists
|
|
return err
|
|
}
|
|
|
|
annsField, err := GetAttrByKeyFromRepeatedKV(AnnsFieldKey, st.query.SearchParams)
|
|
if err != nil {
|
|
return errors.New(AnnsFieldKey + " not found in search_params")
|
|
}
|
|
|
|
topKStr, err := GetAttrByKeyFromRepeatedKV(TopKKey, st.query.SearchParams)
|
|
if err != nil {
|
|
return errors.New(TopKKey + " not found in search_params")
|
|
}
|
|
topK, err := strconv.Atoi(topKStr)
|
|
if err != nil {
|
|
return errors.New(TopKKey + " " + topKStr + " is not invalid")
|
|
}
|
|
|
|
metricType, err := GetAttrByKeyFromRepeatedKV(MetricTypeKey, st.query.SearchParams)
|
|
if err != nil {
|
|
return errors.New(MetricTypeKey + " not found in search_params")
|
|
}
|
|
|
|
searchParams, err := GetAttrByKeyFromRepeatedKV(SearchParamsKey, st.query.SearchParams)
|
|
if err != nil {
|
|
return errors.New(SearchParamsKey + " not found in search_params")
|
|
}
|
|
|
|
queryInfo := &planpb.QueryInfo{
|
|
Topk: int64(topK),
|
|
MetricType: metricType,
|
|
SearchParams: searchParams,
|
|
}
|
|
|
|
plan, err := CreateQueryPlan(schema, &st.query.Dsl, annsField, queryInfo)
|
|
if err != nil {
|
|
return errors.New("invalid expression: " + st.query.Dsl)
|
|
}
|
|
|
|
dsl = proto.MarshalTextString(plan)
|
|
st.query.Dsl = dsl
|
|
}
|
|
|
|
queryBytes, err := proto.Marshal(st.query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
st.Query = &commonpb.Blob{
|
|
Value: queryBytes,
|
|
}
|
|
|
|
st.ResultChannelID = Params.SearchResultChannelNames[0]
|
|
st.DbID = 0 // todo
|
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
if err != nil { // err is not nil if collection not exists
|
|
return err
|
|
}
|
|
st.CollectionID = collectionID
|
|
st.PartitionIDs = make([]UniqueID, 0)
|
|
|
|
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
partitionsRecord := make(map[UniqueID]bool)
|
|
for _, partitionName := range st.query.PartitionNames {
|
|
pattern := fmt.Sprintf("^%s$", partitionName)
|
|
re, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return errors.New("invalid partition names")
|
|
}
|
|
found := false
|
|
for name, pID := range partitionsMap {
|
|
if re.MatchString(name) {
|
|
if _, exist := partitionsRecord[pID]; !exist {
|
|
st.PartitionIDs = append(st.PartitionIDs, pID)
|
|
partitionsRecord[pID] = true
|
|
}
|
|
found = true
|
|
}
|
|
}
|
|
if !found {
|
|
errMsg := fmt.Sprintf("PartitonName: %s not found", partitionName)
|
|
return errors.New(errMsg)
|
|
}
|
|
}
|
|
|
|
st.Dsl = dsl
|
|
st.PlaceholderGroup = st.query.PlaceholderGroup
|
|
|
|
return nil
|
|
}
|
|
|
|
func (st *SearchTask) Execute(ctx context.Context) error {
|
|
var tsMsg msgstream.TsMsg = &msgstream.SearchMsg{
|
|
SearchRequest: *st.SearchRequest,
|
|
BaseMsg: msgstream.BaseMsg{
|
|
Ctx: ctx,
|
|
HashValues: []uint32{uint32(Params.ProxyID)},
|
|
BeginTimestamp: st.Base.Timestamp,
|
|
EndTimestamp: st.Base.Timestamp,
|
|
},
|
|
}
|
|
msgPack := msgstream.MsgPack{
|
|
BeginTs: st.Base.Timestamp,
|
|
EndTs: st.Base.Timestamp,
|
|
Msgs: make([]msgstream.TsMsg, 1),
|
|
}
|
|
msgPack.Msgs[0] = tsMsg
|
|
err := st.queryMsgStream.Produce(&msgPack)
|
|
log.Debug("proxynode", zap.Int("length of searchMsg", len(msgPack.Msgs)))
|
|
if err != nil {
|
|
log.Debug("proxynode", zap.String("send search request failed", err.Error()))
|
|
}
|
|
return err
|
|
}
|
|
|
|
// TODO: add benchmark to compare with serial implementation
|
|
func decodeSearchResultsParallel(searchResults []*internalpb.SearchResults, maxParallel int) ([][]*milvuspb.Hits, error) {
|
|
log.Debug("decodeSearchResultsParallel", zap.Any("NumOfGoRoutines", maxParallel))
|
|
|
|
hits := make([][]*milvuspb.Hits, 0)
|
|
// necessary to parallel this?
|
|
for _, partialSearchResult := range searchResults {
|
|
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
|
|
continue
|
|
}
|
|
|
|
nq := len(partialSearchResult.Hits)
|
|
partialHits := make([]*milvuspb.Hits, nq)
|
|
|
|
f := func(idx int) error {
|
|
partialHit := &milvuspb.Hits{}
|
|
|
|
err := proto.Unmarshal(partialSearchResult.Hits[idx], partialHit)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
partialHits[idx] = partialHit
|
|
|
|
return nil
|
|
}
|
|
|
|
err := funcutil.ProcessFuncParallel(nq, maxParallel, f, "decodePartialSearchResult")
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hits = append(hits, partialHits)
|
|
}
|
|
|
|
return hits, nil
|
|
}
|
|
|
|
func decodeSearchResultsSerial(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
|
|
return decodeSearchResultsParallel(searchResults, 1)
|
|
}
|
|
|
|
// TODO: add benchmark to compare with serial implementation
|
|
func decodeSearchResultsParallelByNq(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
|
|
if len(searchResults) <= 0 {
|
|
return nil, errors.New("no need to decode empty search results")
|
|
}
|
|
nq := len(searchResults[0].Hits)
|
|
return decodeSearchResultsParallel(searchResults, nq)
|
|
}
|
|
|
|
// TODO: add benchmark to compare with serial implementation
|
|
func decodeSearchResultsParallelByCPU(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
|
|
return decodeSearchResultsParallel(searchResults, runtime.NumCPU())
|
|
}
|
|
|
|
func decodeSearchResults(searchResults []*internalpb.SearchResults) ([][]*milvuspb.Hits, error) {
|
|
t := time.Now()
|
|
defer func() {
|
|
log.Debug("decodeSearchResults", zap.Any("time cost", time.Since(t)))
|
|
}()
|
|
return decodeSearchResultsParallelByCPU(searchResults)
|
|
}
|
|
|
|
func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) *milvuspb.SearchResults {
|
|
log.Debug("reduceSearchResultsParallel", zap.Any("NumOfGoRoutines", maxParallel))
|
|
|
|
ret := &milvuspb.SearchResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: 0,
|
|
},
|
|
Hits: make([][]byte, nq),
|
|
}
|
|
|
|
const minFloat32 = -1 * float32(math.MaxFloat32)
|
|
|
|
f := func(idx int) error {
|
|
locs := make([]int, availableQueryNodeNum)
|
|
reducedHits := &milvuspb.Hits{
|
|
IDs: make([]int64, 0),
|
|
RowData: make([][]byte, 0),
|
|
Scores: make([]float32, 0),
|
|
}
|
|
|
|
for j := 0; j < topk; j++ {
|
|
valid := false
|
|
choice, maxDistance := 0, minFloat32
|
|
for q, loc := range locs { // query num, the number of ways to merge
|
|
if loc >= len(hits[q][idx].IDs) {
|
|
continue
|
|
}
|
|
distance := hits[q][idx].Scores[loc]
|
|
if distance > maxDistance || (math.Abs(float64(distance-maxDistance)) < math.SmallestNonzeroFloat32 && choice != q) {
|
|
choice = q
|
|
maxDistance = distance
|
|
valid = true
|
|
}
|
|
}
|
|
if !valid {
|
|
break
|
|
}
|
|
choiceOffset := locs[choice]
|
|
// check if distance is valid, `invalid` here means very very big,
|
|
// in this process, distance here is the smallest, so the rest of distance are all invalid
|
|
if hits[choice][idx].Scores[choiceOffset] <= minFloat32 {
|
|
break
|
|
}
|
|
reducedHits.IDs = append(reducedHits.IDs, hits[choice][idx].IDs[choiceOffset])
|
|
if hits[choice][idx].RowData != nil && len(hits[choice][idx].RowData) > 0 {
|
|
reducedHits.RowData = append(reducedHits.RowData, hits[choice][idx].RowData[choiceOffset])
|
|
}
|
|
reducedHits.Scores = append(reducedHits.Scores, hits[choice][idx].Scores[choiceOffset])
|
|
locs[choice]++
|
|
}
|
|
|
|
if metricType != "IP" {
|
|
for k := range reducedHits.Scores {
|
|
reducedHits.Scores[k] *= -1
|
|
}
|
|
}
|
|
|
|
reducedHitsBs, err := proto.Marshal(reducedHits)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ret.Hits[idx] = reducedHitsBs
|
|
|
|
return nil
|
|
}
|
|
|
|
err := funcutil.ProcessFuncParallel(nq, maxParallel, f, "reduceSearchResults")
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
return ret
|
|
}
|
|
|
|
func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
|
|
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType, 1)
|
|
}
|
|
|
|
// TODO: add benchmark to compare with serial implementation
|
|
func reduceSearchResultsParallelByNq(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
|
|
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType, nq)
|
|
}
|
|
|
|
// TODO: add benchmark to compare with serial implementation
|
|
func reduceSearchResultsParallelByCPU(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
|
|
return reduceSearchResultsParallel(hits, nq, availableQueryNodeNum, topk, metricType, runtime.NumCPU())
|
|
}
|
|
|
|
func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
|
|
t := time.Now()
|
|
defer func() {
|
|
log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
|
|
}()
|
|
return reduceSearchResultsParallelByCPU(hits, nq, availableQueryNodeNum, topk, metricType)
|
|
}
|
|
|
|
func printSearchResult(partialSearchResult *internalpb.SearchResults) {
|
|
for i := 0; i < len(partialSearchResult.Hits); i++ {
|
|
testHits := milvuspb.Hits{}
|
|
err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
fmt.Println(testHits.IDs)
|
|
fmt.Println(testHits.Scores)
|
|
}
|
|
}
|
|
|
|
func (st *SearchTask) PostExecute(ctx context.Context) error {
|
|
t0 := time.Now()
|
|
defer func() {
|
|
log.Debug("WaitAndPostExecute", zap.Any("time cost", time.Since(t0)))
|
|
}()
|
|
for {
|
|
select {
|
|
case <-st.TraceCtx().Done():
|
|
log.Debug("proxynode", zap.Int64("SearchTask: wait to finish failed, timeout!, taskID:", st.ID()))
|
|
return fmt.Errorf("SearchTask:wait to finish failed, timeout: %d", st.ID())
|
|
case searchResults := <-st.resultBuf:
|
|
// fmt.Println("searchResults: ", searchResults)
|
|
filterSearchResult := make([]*internalpb.SearchResults, 0)
|
|
var filterReason string
|
|
for _, partialSearchResult := range searchResults {
|
|
if partialSearchResult.Status.ErrorCode == commonpb.ErrorCode_Success {
|
|
filterSearchResult = append(filterSearchResult, partialSearchResult)
|
|
// For debugging, please don't delete.
|
|
// printSearchResult(partialSearchResult)
|
|
} else {
|
|
filterReason += partialSearchResult.Status.Reason + "\n"
|
|
}
|
|
}
|
|
|
|
availableQueryNodeNum := len(filterSearchResult)
|
|
if availableQueryNodeNum <= 0 {
|
|
st.result = &milvuspb.SearchResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
Reason: filterReason,
|
|
},
|
|
}
|
|
return errors.New(filterReason)
|
|
}
|
|
|
|
availableQueryNodeNum = 0
|
|
for _, partialSearchResult := range filterSearchResult {
|
|
if partialSearchResult.Hits == nil || len(partialSearchResult.Hits) <= 0 {
|
|
filterReason += "nq is zero\n"
|
|
continue
|
|
}
|
|
availableQueryNodeNum++
|
|
}
|
|
|
|
if availableQueryNodeNum <= 0 {
|
|
st.result = &milvuspb.SearchResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
Reason: filterReason,
|
|
},
|
|
}
|
|
return nil
|
|
}
|
|
|
|
hits, err := decodeSearchResults(filterSearchResult)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
nq := len(hits[0])
|
|
if nq <= 0 {
|
|
st.result = &milvuspb.SearchResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
Reason: filterReason,
|
|
},
|
|
}
|
|
return nil
|
|
}
|
|
|
|
topk := 0
|
|
for _, hit := range hits {
|
|
topk = getMax(topk, len(hit[0].IDs))
|
|
}
|
|
|
|
st.result = reduceSearchResults(hits, nq, availableQueryNodeNum, topk, searchResults[0].MetricType)
|
|
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
|
|
type HasCollectionTask struct {
|
|
Condition
|
|
*milvuspb.HasCollectionRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *milvuspb.BoolResponse
|
|
}
|
|
|
|
func (hct *HasCollectionTask) TraceCtx() context.Context {
|
|
return hct.ctx
|
|
}
|
|
|
|
func (hct *HasCollectionTask) ID() UniqueID {
|
|
return hct.Base.MsgID
|
|
}
|
|
|
|
func (hct *HasCollectionTask) SetID(uid UniqueID) {
|
|
hct.Base.MsgID = uid
|
|
}
|
|
|
|
func (hct *HasCollectionTask) Name() string {
|
|
return HasCollectionTaskName
|
|
}
|
|
|
|
func (hct *HasCollectionTask) Type() commonpb.MsgType {
|
|
return hct.Base.MsgType
|
|
}
|
|
|
|
func (hct *HasCollectionTask) BeginTs() Timestamp {
|
|
return hct.Base.Timestamp
|
|
}
|
|
|
|
func (hct *HasCollectionTask) EndTs() Timestamp {
|
|
return hct.Base.Timestamp
|
|
}
|
|
|
|
func (hct *HasCollectionTask) SetTs(ts Timestamp) {
|
|
hct.Base.Timestamp = ts
|
|
}
|
|
|
|
func (hct *HasCollectionTask) OnEnqueue() error {
|
|
hct.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (hct *HasCollectionTask) PreExecute(ctx context.Context) error {
|
|
hct.Base.MsgType = commonpb.MsgType_HasCollection
|
|
hct.Base.SourceID = Params.ProxyID
|
|
|
|
if err := ValidateCollectionName(hct.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hct *HasCollectionTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
hct.result, err = hct.masterService.HasCollection(ctx, hct.HasCollectionRequest)
|
|
if hct.result == nil {
|
|
return errors.New("has collection resp is nil")
|
|
}
|
|
if hct.result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(hct.result.Status.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (hct *HasCollectionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type DescribeCollectionTask struct {
|
|
Condition
|
|
*milvuspb.DescribeCollectionRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *milvuspb.DescribeCollectionResponse
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) TraceCtx() context.Context {
|
|
return dct.ctx
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) ID() UniqueID {
|
|
return dct.Base.MsgID
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) SetID(uid UniqueID) {
|
|
dct.Base.MsgID = uid
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) Name() string {
|
|
return DescribeCollectionTaskName
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) Type() commonpb.MsgType {
|
|
return dct.Base.MsgType
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) BeginTs() Timestamp {
|
|
return dct.Base.Timestamp
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) EndTs() Timestamp {
|
|
return dct.Base.Timestamp
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) SetTs(ts Timestamp) {
|
|
dct.Base.Timestamp = ts
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) OnEnqueue() error {
|
|
dct.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) PreExecute(ctx context.Context) error {
|
|
dct.Base.MsgType = commonpb.MsgType_DescribeCollection
|
|
dct.Base.SourceID = Params.ProxyID
|
|
|
|
if err := ValidateCollectionName(dct.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
dct.result, err = dct.masterService.DescribeCollection(ctx, dct.DescribeCollectionRequest)
|
|
if dct.result == nil {
|
|
return errors.New("has collection resp is nil")
|
|
}
|
|
if dct.result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(dct.result.Status.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (dct *DescribeCollectionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type GetCollectionsStatisticsTask struct {
|
|
Condition
|
|
*milvuspb.GetCollectionStatisticsRequest
|
|
ctx context.Context
|
|
dataService types.DataService
|
|
result *milvuspb.GetCollectionStatisticsResponse
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) TraceCtx() context.Context {
|
|
return g.ctx
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) ID() UniqueID {
|
|
return g.Base.MsgID
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) SetID(uid UniqueID) {
|
|
g.Base.MsgID = uid
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) Name() string {
|
|
return GetCollectionStatisticsTaskName
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) Type() commonpb.MsgType {
|
|
return g.Base.MsgType
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) BeginTs() Timestamp {
|
|
return g.Base.Timestamp
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) EndTs() Timestamp {
|
|
return g.Base.Timestamp
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) SetTs(ts Timestamp) {
|
|
g.Base.Timestamp = ts
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) OnEnqueue() error {
|
|
g.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) PreExecute(ctx context.Context) error {
|
|
g.Base.MsgType = commonpb.MsgType_GetCollectionStatistics
|
|
g.Base.SourceID = Params.ProxyID
|
|
return nil
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) Execute(ctx context.Context) error {
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, g.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req := &datapb.GetCollectionStatisticsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_GetCollectionStatistics,
|
|
MsgID: g.Base.MsgID,
|
|
Timestamp: g.Base.Timestamp,
|
|
SourceID: g.Base.SourceID,
|
|
},
|
|
CollectionID: collID,
|
|
}
|
|
|
|
result, _ := g.dataService.GetCollectionStatistics(ctx, req)
|
|
if result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(result.Status.Reason)
|
|
}
|
|
g.result = &milvuspb.GetCollectionStatisticsResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
Reason: "",
|
|
},
|
|
Stats: result.Stats,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (g *GetCollectionsStatisticsTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type ShowCollectionsTask struct {
|
|
Condition
|
|
*milvuspb.ShowCollectionsRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *milvuspb.ShowCollectionsResponse
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) TraceCtx() context.Context {
|
|
return sct.ctx
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) ID() UniqueID {
|
|
return sct.Base.MsgID
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) SetID(uid UniqueID) {
|
|
sct.Base.MsgID = uid
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) Name() string {
|
|
return ShowCollectionTaskName
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) Type() commonpb.MsgType {
|
|
return sct.Base.MsgType
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) BeginTs() Timestamp {
|
|
return sct.Base.Timestamp
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) EndTs() Timestamp {
|
|
return sct.Base.Timestamp
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) SetTs(ts Timestamp) {
|
|
sct.Base.Timestamp = ts
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) OnEnqueue() error {
|
|
sct.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) PreExecute(ctx context.Context) error {
|
|
sct.Base.MsgType = commonpb.MsgType_ShowCollections
|
|
sct.Base.SourceID = Params.ProxyID
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
sct.result, err = sct.masterService.ShowCollections(ctx, sct.ShowCollectionsRequest)
|
|
if sct.result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if sct.result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(sct.result.Status.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (sct *ShowCollectionsTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type CreatePartitionTask struct {
|
|
Condition
|
|
*milvuspb.CreatePartitionRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) TraceCtx() context.Context {
|
|
return cpt.ctx
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) ID() UniqueID {
|
|
return cpt.Base.MsgID
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) SetID(uid UniqueID) {
|
|
cpt.Base.MsgID = uid
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) Name() string {
|
|
return CreatePartitionTaskName
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) Type() commonpb.MsgType {
|
|
return cpt.Base.MsgType
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) BeginTs() Timestamp {
|
|
return cpt.Base.Timestamp
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) EndTs() Timestamp {
|
|
return cpt.Base.Timestamp
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) SetTs(ts Timestamp) {
|
|
cpt.Base.Timestamp = ts
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) OnEnqueue() error {
|
|
cpt.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) PreExecute(ctx context.Context) error {
|
|
cpt.Base.MsgType = commonpb.MsgType_CreatePartition
|
|
cpt.Base.SourceID = Params.ProxyID
|
|
|
|
collName, partitionTag := cpt.CollectionName, cpt.PartitionName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) Execute(ctx context.Context) (err error) {
|
|
cpt.result, err = cpt.masterService.CreatePartition(ctx, cpt.CreatePartitionRequest)
|
|
if cpt.result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if cpt.result.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(cpt.result.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (cpt *CreatePartitionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type DropPartitionTask struct {
|
|
Condition
|
|
*milvuspb.DropPartitionRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) TraceCtx() context.Context {
|
|
return dpt.ctx
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) ID() UniqueID {
|
|
return dpt.Base.MsgID
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) SetID(uid UniqueID) {
|
|
dpt.Base.MsgID = uid
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) Name() string {
|
|
return DropPartitionTaskName
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) Type() commonpb.MsgType {
|
|
return dpt.Base.MsgType
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) BeginTs() Timestamp {
|
|
return dpt.Base.Timestamp
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) EndTs() Timestamp {
|
|
return dpt.Base.Timestamp
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) SetTs(ts Timestamp) {
|
|
dpt.Base.Timestamp = ts
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) OnEnqueue() error {
|
|
dpt.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) PreExecute(ctx context.Context) error {
|
|
dpt.Base.MsgType = commonpb.MsgType_DropPartition
|
|
dpt.Base.SourceID = Params.ProxyID
|
|
|
|
collName, partitionTag := dpt.CollectionName, dpt.PartitionName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) Execute(ctx context.Context) (err error) {
|
|
dpt.result, err = dpt.masterService.DropPartition(ctx, dpt.DropPartitionRequest)
|
|
if dpt.result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if dpt.result.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(dpt.result.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (dpt *DropPartitionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type HasPartitionTask struct {
|
|
Condition
|
|
*milvuspb.HasPartitionRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *milvuspb.BoolResponse
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) TraceCtx() context.Context {
|
|
return hpt.ctx
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) ID() UniqueID {
|
|
return hpt.Base.MsgID
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) SetID(uid UniqueID) {
|
|
hpt.Base.MsgID = uid
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) Name() string {
|
|
return HasPartitionTaskName
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) Type() commonpb.MsgType {
|
|
return hpt.Base.MsgType
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) BeginTs() Timestamp {
|
|
return hpt.Base.Timestamp
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) EndTs() Timestamp {
|
|
return hpt.Base.Timestamp
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) SetTs(ts Timestamp) {
|
|
hpt.Base.Timestamp = ts
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) OnEnqueue() error {
|
|
hpt.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) PreExecute(ctx context.Context) error {
|
|
hpt.Base.MsgType = commonpb.MsgType_HasPartition
|
|
hpt.Base.SourceID = Params.ProxyID
|
|
|
|
collName, partitionTag := hpt.CollectionName, hpt.PartitionName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ValidatePartitionTag(partitionTag, true); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) Execute(ctx context.Context) (err error) {
|
|
hpt.result, err = hpt.masterService.HasPartition(ctx, hpt.HasPartitionRequest)
|
|
if hpt.result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if hpt.result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(hpt.result.Status.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (hpt *HasPartitionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type ShowPartitionsTask struct {
|
|
Condition
|
|
*milvuspb.ShowPartitionsRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *milvuspb.ShowPartitionsResponse
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) TraceCtx() context.Context {
|
|
return spt.ctx
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) ID() UniqueID {
|
|
return spt.Base.MsgID
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) SetID(uid UniqueID) {
|
|
spt.Base.MsgID = uid
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) Name() string {
|
|
return ShowPartitionTaskName
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) Type() commonpb.MsgType {
|
|
return spt.Base.MsgType
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) BeginTs() Timestamp {
|
|
return spt.Base.Timestamp
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) EndTs() Timestamp {
|
|
return spt.Base.Timestamp
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) SetTs(ts Timestamp) {
|
|
spt.Base.Timestamp = ts
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) OnEnqueue() error {
|
|
spt.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) PreExecute(ctx context.Context) error {
|
|
spt.Base.MsgType = commonpb.MsgType_ShowPartitions
|
|
spt.Base.SourceID = Params.ProxyID
|
|
|
|
if err := ValidateCollectionName(spt.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
spt.result, err = spt.masterService.ShowPartitions(ctx, spt.ShowPartitionsRequest)
|
|
if spt.result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if spt.result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(spt.result.Status.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (spt *ShowPartitionsTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type CreateIndexTask struct {
|
|
Condition
|
|
*milvuspb.CreateIndexRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (cit *CreateIndexTask) TraceCtx() context.Context {
|
|
return cit.ctx
|
|
}
|
|
|
|
func (cit *CreateIndexTask) ID() UniqueID {
|
|
return cit.Base.MsgID
|
|
}
|
|
|
|
func (cit *CreateIndexTask) SetID(uid UniqueID) {
|
|
cit.Base.MsgID = uid
|
|
}
|
|
|
|
func (cit *CreateIndexTask) Name() string {
|
|
return CreateIndexTaskName
|
|
}
|
|
|
|
func (cit *CreateIndexTask) Type() commonpb.MsgType {
|
|
return cit.Base.MsgType
|
|
}
|
|
|
|
func (cit *CreateIndexTask) BeginTs() Timestamp {
|
|
return cit.Base.Timestamp
|
|
}
|
|
|
|
func (cit *CreateIndexTask) EndTs() Timestamp {
|
|
return cit.Base.Timestamp
|
|
}
|
|
|
|
func (cit *CreateIndexTask) SetTs(ts Timestamp) {
|
|
cit.Base.Timestamp = ts
|
|
}
|
|
|
|
func (cit *CreateIndexTask) OnEnqueue() error {
|
|
cit.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (cit *CreateIndexTask) PreExecute(ctx context.Context) error {
|
|
cit.Base.MsgType = commonpb.MsgType_CreateIndex
|
|
cit.Base.SourceID = Params.ProxyID
|
|
|
|
collName, fieldName := cit.CollectionName, cit.FieldName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ValidateFieldName(fieldName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (cit *CreateIndexTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
cit.result, err = cit.masterService.CreateIndex(ctx, cit.CreateIndexRequest)
|
|
if cit.result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if cit.result.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(cit.result.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (cit *CreateIndexTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type DescribeIndexTask struct {
|
|
Condition
|
|
*milvuspb.DescribeIndexRequest
|
|
ctx context.Context
|
|
masterService types.MasterService
|
|
result *milvuspb.DescribeIndexResponse
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) TraceCtx() context.Context {
|
|
return dit.ctx
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) ID() UniqueID {
|
|
return dit.Base.MsgID
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) SetID(uid UniqueID) {
|
|
dit.Base.MsgID = uid
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) Name() string {
|
|
return DescribeIndexTaskName
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) Type() commonpb.MsgType {
|
|
return dit.Base.MsgType
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) BeginTs() Timestamp {
|
|
return dit.Base.Timestamp
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) EndTs() Timestamp {
|
|
return dit.Base.Timestamp
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) SetTs(ts Timestamp) {
|
|
dit.Base.Timestamp = ts
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) OnEnqueue() error {
|
|
dit.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) PreExecute(ctx context.Context) error {
|
|
dit.Base.MsgType = commonpb.MsgType_DescribeIndex
|
|
dit.Base.SourceID = Params.ProxyID
|
|
|
|
if err := ValidateCollectionName(dit.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
|
|
// only support default index name for now. @2021.02.18
|
|
if dit.IndexName == "" {
|
|
dit.IndexName = Params.DefaultIndexName
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
dit.result, err = dit.masterService.DescribeIndex(ctx, dit.DescribeIndexRequest)
|
|
if dit.result == nil {
|
|
return errors.New("get collection statistics resp is nil")
|
|
}
|
|
if dit.result.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(dit.result.Status.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (dit *DescribeIndexTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type DropIndexTask struct {
|
|
Condition
|
|
ctx context.Context
|
|
*milvuspb.DropIndexRequest
|
|
masterService types.MasterService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (dit *DropIndexTask) TraceCtx() context.Context {
|
|
return dit.ctx
|
|
}
|
|
|
|
func (dit *DropIndexTask) ID() UniqueID {
|
|
return dit.Base.MsgID
|
|
}
|
|
|
|
func (dit *DropIndexTask) SetID(uid UniqueID) {
|
|
dit.Base.MsgID = uid
|
|
}
|
|
|
|
func (dit *DropIndexTask) Name() string {
|
|
return DropIndexTaskName
|
|
}
|
|
|
|
func (dit *DropIndexTask) Type() commonpb.MsgType {
|
|
return dit.Base.MsgType
|
|
}
|
|
|
|
func (dit *DropIndexTask) BeginTs() Timestamp {
|
|
return dit.Base.Timestamp
|
|
}
|
|
|
|
func (dit *DropIndexTask) EndTs() Timestamp {
|
|
return dit.Base.Timestamp
|
|
}
|
|
|
|
func (dit *DropIndexTask) SetTs(ts Timestamp) {
|
|
dit.Base.Timestamp = ts
|
|
}
|
|
|
|
func (dit *DropIndexTask) OnEnqueue() error {
|
|
dit.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (dit *DropIndexTask) PreExecute(ctx context.Context) error {
|
|
dit.Base.MsgType = commonpb.MsgType_DropIndex
|
|
dit.Base.SourceID = Params.ProxyID
|
|
|
|
collName, fieldName := dit.CollectionName, dit.FieldName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ValidateFieldName(fieldName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (dit *DropIndexTask) Execute(ctx context.Context) error {
|
|
var err error
|
|
dit.result, err = dit.masterService.DropIndex(ctx, dit.DropIndexRequest)
|
|
if dit.result == nil {
|
|
return errors.New("drop index resp is nil")
|
|
}
|
|
if dit.result.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(dit.result.Reason)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (dit *DropIndexTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type GetIndexBuildProgressTask struct {
|
|
Condition
|
|
*milvuspb.GetIndexBuildProgressRequest
|
|
ctx context.Context
|
|
indexService types.IndexService
|
|
masterService types.MasterService
|
|
dataService types.DataService
|
|
result *milvuspb.GetIndexBuildProgressResponse
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) TraceCtx() context.Context {
|
|
return gibpt.ctx
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) ID() UniqueID {
|
|
return gibpt.Base.MsgID
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) SetID(uid UniqueID) {
|
|
gibpt.Base.MsgID = uid
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) Name() string {
|
|
return GetIndexBuildProgressTaskName
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) Type() commonpb.MsgType {
|
|
return gibpt.Base.MsgType
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) BeginTs() Timestamp {
|
|
return gibpt.Base.Timestamp
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) EndTs() Timestamp {
|
|
return gibpt.Base.Timestamp
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) SetTs(ts Timestamp) {
|
|
gibpt.Base.Timestamp = ts
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) OnEnqueue() error {
|
|
gibpt.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) PreExecute(ctx context.Context) error {
|
|
gibpt.Base.MsgType = commonpb.MsgType_GetIndexBuildProgress
|
|
gibpt.Base.SourceID = Params.ProxyID
|
|
|
|
if err := ValidateCollectionName(gibpt.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) Execute(ctx context.Context) error {
|
|
collectionName := gibpt.CollectionName
|
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
if err != nil { // err is not nil if collection not exists
|
|
return err
|
|
}
|
|
|
|
showPartitionRequest := &milvuspb.ShowPartitionsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_ShowPartitions,
|
|
MsgID: gibpt.Base.MsgID,
|
|
Timestamp: gibpt.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
DbName: gibpt.DbName,
|
|
CollectionName: collectionName,
|
|
CollectionID: collectionID,
|
|
}
|
|
partitions, err := gibpt.masterService.ShowPartitions(ctx, showPartitionRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if gibpt.IndexName == "" {
|
|
gibpt.IndexName = Params.DefaultIndexName
|
|
}
|
|
|
|
describeIndexReq := milvuspb.DescribeIndexRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_DescribeIndex,
|
|
MsgID: gibpt.Base.MsgID,
|
|
Timestamp: gibpt.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
DbName: gibpt.DbName,
|
|
CollectionName: gibpt.CollectionName,
|
|
// IndexName: gibpt.IndexName,
|
|
}
|
|
|
|
indexDescriptionResp, err2 := gibpt.masterService.DescribeIndex(ctx, &describeIndexReq)
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
|
|
matchIndexID := int64(-1)
|
|
foundIndexID := false
|
|
for _, desc := range indexDescriptionResp.IndexDescriptions {
|
|
if desc.IndexName == gibpt.IndexName {
|
|
matchIndexID = desc.IndexID
|
|
foundIndexID = true
|
|
break
|
|
}
|
|
}
|
|
if !foundIndexID {
|
|
return errors.New(fmt.Sprint("Can't found IndexID for indexName", gibpt.IndexName))
|
|
}
|
|
|
|
var allSegmentIDs []UniqueID
|
|
for _, partitionID := range partitions.PartitionIDs {
|
|
showSegmentsRequest := &milvuspb.ShowSegmentsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_ShowSegments,
|
|
MsgID: gibpt.Base.MsgID,
|
|
Timestamp: gibpt.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
CollectionID: collectionID,
|
|
PartitionID: partitionID,
|
|
}
|
|
segments, err := gibpt.masterService.ShowSegments(ctx, showSegmentsRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if segments.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(segments.Status.Reason)
|
|
}
|
|
allSegmentIDs = append(allSegmentIDs, segments.SegmentIDs...)
|
|
}
|
|
|
|
getIndexStatesRequest := &indexpb.GetIndexStatesRequest{
|
|
IndexBuildIDs: make([]UniqueID, 0),
|
|
}
|
|
|
|
buildIndexMap := make(map[int64]int64)
|
|
for _, segmentID := range allSegmentIDs {
|
|
describeSegmentRequest := &milvuspb.DescribeSegmentRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_DescribeSegment,
|
|
MsgID: gibpt.Base.MsgID,
|
|
Timestamp: gibpt.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
CollectionID: collectionID,
|
|
SegmentID: segmentID,
|
|
}
|
|
segmentDesc, err := gibpt.masterService.DescribeSegment(ctx, describeSegmentRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if segmentDesc.IndexID == matchIndexID {
|
|
if segmentDesc.EnableIndex {
|
|
getIndexStatesRequest.IndexBuildIDs = append(getIndexStatesRequest.IndexBuildIDs, segmentDesc.BuildID)
|
|
buildIndexMap[segmentID] = segmentDesc.BuildID
|
|
}
|
|
}
|
|
}
|
|
|
|
states, err := gibpt.indexService.GetIndexStates(ctx, getIndexStatesRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if states.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
gibpt.result = &milvuspb.GetIndexBuildProgressResponse{
|
|
Status: states.Status,
|
|
}
|
|
}
|
|
|
|
buildFinishMap := make(map[int64]bool)
|
|
for _, state := range states.States {
|
|
if state.State == commonpb.IndexState_Finished {
|
|
buildFinishMap[state.IndexBuildID] = true
|
|
}
|
|
}
|
|
|
|
infoResp, err := gibpt.dataService.GetSegmentInfo(ctx, &datapb.GetSegmentInfoRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_SegmentInfo,
|
|
MsgID: 0,
|
|
Timestamp: 0,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
SegmentIDs: allSegmentIDs,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
total := int64(0)
|
|
indexed := int64(0)
|
|
|
|
for _, info := range infoResp.Infos {
|
|
total += info.NumRows
|
|
if buildFinishMap[buildIndexMap[info.ID]] {
|
|
indexed += info.NumRows
|
|
}
|
|
}
|
|
|
|
gibpt.result = &milvuspb.GetIndexBuildProgressResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
Reason: "",
|
|
},
|
|
TotalRows: total,
|
|
IndexedRows: indexed,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (gibpt *GetIndexBuildProgressTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type GetIndexStateTask struct {
|
|
Condition
|
|
*milvuspb.GetIndexStateRequest
|
|
ctx context.Context
|
|
indexService types.IndexService
|
|
masterService types.MasterService
|
|
result *milvuspb.GetIndexStateResponse
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) TraceCtx() context.Context {
|
|
return gist.ctx
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) ID() UniqueID {
|
|
return gist.Base.MsgID
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) SetID(uid UniqueID) {
|
|
gist.Base.MsgID = uid
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) Name() string {
|
|
return GetIndexStateTaskName
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) Type() commonpb.MsgType {
|
|
return gist.Base.MsgType
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) BeginTs() Timestamp {
|
|
return gist.Base.Timestamp
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) EndTs() Timestamp {
|
|
return gist.Base.Timestamp
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) SetTs(ts Timestamp) {
|
|
gist.Base.Timestamp = ts
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) OnEnqueue() error {
|
|
gist.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) PreExecute(ctx context.Context) error {
|
|
gist.Base.MsgType = commonpb.MsgType_GetIndexState
|
|
gist.Base.SourceID = Params.ProxyID
|
|
|
|
if err := ValidateCollectionName(gist.CollectionName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) Execute(ctx context.Context) error {
|
|
collectionName := gist.CollectionName
|
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
|
if err != nil { // err is not nil if collection not exists
|
|
return err
|
|
}
|
|
|
|
showPartitionRequest := &milvuspb.ShowPartitionsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_ShowPartitions,
|
|
MsgID: gist.Base.MsgID,
|
|
Timestamp: gist.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
DbName: gist.DbName,
|
|
CollectionName: collectionName,
|
|
CollectionID: collectionID,
|
|
}
|
|
partitions, err := gist.masterService.ShowPartitions(ctx, showPartitionRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if gist.IndexName == "" {
|
|
gist.IndexName = Params.DefaultIndexName
|
|
}
|
|
|
|
describeIndexReq := milvuspb.DescribeIndexRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_DescribeIndex,
|
|
MsgID: gist.Base.MsgID,
|
|
Timestamp: gist.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
DbName: gist.DbName,
|
|
CollectionName: gist.CollectionName,
|
|
IndexName: gist.IndexName,
|
|
}
|
|
|
|
indexDescriptionResp, err2 := gist.masterService.DescribeIndex(ctx, &describeIndexReq)
|
|
if err2 != nil {
|
|
return err2
|
|
}
|
|
|
|
matchIndexID := int64(-1)
|
|
foundIndexID := false
|
|
for _, desc := range indexDescriptionResp.IndexDescriptions {
|
|
if desc.IndexName == gist.IndexName {
|
|
matchIndexID = desc.IndexID
|
|
foundIndexID = true
|
|
break
|
|
}
|
|
}
|
|
if !foundIndexID {
|
|
return errors.New(fmt.Sprint("Can't found IndexID for indexName", gist.IndexName))
|
|
}
|
|
|
|
var allSegmentIDs []UniqueID
|
|
for _, partitionID := range partitions.PartitionIDs {
|
|
showSegmentsRequest := &milvuspb.ShowSegmentsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_ShowSegments,
|
|
MsgID: gist.Base.MsgID,
|
|
Timestamp: gist.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
CollectionID: collectionID,
|
|
PartitionID: partitionID,
|
|
}
|
|
segments, err := gist.masterService.ShowSegments(ctx, showSegmentsRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if segments.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(segments.Status.Reason)
|
|
}
|
|
allSegmentIDs = append(allSegmentIDs, segments.SegmentIDs...)
|
|
}
|
|
|
|
getIndexStatesRequest := &indexpb.GetIndexStatesRequest{
|
|
IndexBuildIDs: make([]UniqueID, 0),
|
|
}
|
|
enableIndexBitMap := make([]bool, 0)
|
|
indexBuildIDs := make([]UniqueID, 0)
|
|
|
|
for _, segmentID := range allSegmentIDs {
|
|
describeSegmentRequest := &milvuspb.DescribeSegmentRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_DescribeSegment,
|
|
MsgID: gist.Base.MsgID,
|
|
Timestamp: gist.Base.Timestamp,
|
|
SourceID: Params.ProxyID,
|
|
},
|
|
CollectionID: collectionID,
|
|
SegmentID: segmentID,
|
|
}
|
|
segmentDesc, err := gist.masterService.DescribeSegment(ctx, describeSegmentRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if segmentDesc.IndexID == matchIndexID {
|
|
indexBuildIDs = append(indexBuildIDs, segmentDesc.BuildID)
|
|
if segmentDesc.EnableIndex {
|
|
enableIndexBitMap = append(enableIndexBitMap, true)
|
|
} else {
|
|
enableIndexBitMap = append(enableIndexBitMap, false)
|
|
}
|
|
}
|
|
}
|
|
|
|
log.Debug("proxynode", zap.Int("GetIndexState:: len of allSegmentIDs", len(allSegmentIDs)))
|
|
log.Debug("proxynode", zap.Int("GetIndexState:: len of IndexBuildIDs", len(indexBuildIDs)))
|
|
if len(allSegmentIDs) != len(indexBuildIDs) {
|
|
gist.result = &milvuspb.GetIndexStateResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
Reason: "",
|
|
},
|
|
State: commonpb.IndexState_InProgress,
|
|
}
|
|
return err
|
|
}
|
|
|
|
for idx, enableIndex := range enableIndexBitMap {
|
|
if enableIndex {
|
|
getIndexStatesRequest.IndexBuildIDs = append(getIndexStatesRequest.IndexBuildIDs, indexBuildIDs[idx])
|
|
}
|
|
}
|
|
states, err := gist.indexService.GetIndexStates(ctx, getIndexStatesRequest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if states.Status.ErrorCode != commonpb.ErrorCode_Success {
|
|
gist.result = &milvuspb.GetIndexStateResponse{
|
|
Status: states.Status,
|
|
State: commonpb.IndexState_Failed,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
for _, state := range states.States {
|
|
if state.State != commonpb.IndexState_Finished {
|
|
gist.result = &milvuspb.GetIndexStateResponse{
|
|
Status: states.Status,
|
|
State: state.State,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
gist.result = &milvuspb.GetIndexStateResponse{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
Reason: "",
|
|
},
|
|
State: commonpb.IndexState_Finished,
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (gist *GetIndexStateTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type FlushTask struct {
|
|
Condition
|
|
*milvuspb.FlushRequest
|
|
ctx context.Context
|
|
dataService types.DataService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (ft *FlushTask) TraceCtx() context.Context {
|
|
return ft.ctx
|
|
}
|
|
|
|
func (ft *FlushTask) ID() UniqueID {
|
|
return ft.Base.MsgID
|
|
}
|
|
|
|
func (ft *FlushTask) SetID(uid UniqueID) {
|
|
ft.Base.MsgID = uid
|
|
}
|
|
|
|
func (ft *FlushTask) Name() string {
|
|
return FlushTaskName
|
|
}
|
|
|
|
func (ft *FlushTask) Type() commonpb.MsgType {
|
|
return ft.Base.MsgType
|
|
}
|
|
|
|
func (ft *FlushTask) BeginTs() Timestamp {
|
|
return ft.Base.Timestamp
|
|
}
|
|
|
|
func (ft *FlushTask) EndTs() Timestamp {
|
|
return ft.Base.Timestamp
|
|
}
|
|
|
|
func (ft *FlushTask) SetTs(ts Timestamp) {
|
|
ft.Base.Timestamp = ts
|
|
}
|
|
|
|
func (ft *FlushTask) OnEnqueue() error {
|
|
ft.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (ft *FlushTask) PreExecute(ctx context.Context) error {
|
|
ft.Base.MsgType = commonpb.MsgType_Flush
|
|
ft.Base.SourceID = Params.ProxyID
|
|
return nil
|
|
}
|
|
|
|
func (ft *FlushTask) Execute(ctx context.Context) error {
|
|
for _, collName := range ft.CollectionNames {
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, collName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
flushReq := &datapb.FlushRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Flush,
|
|
MsgID: ft.Base.MsgID,
|
|
Timestamp: ft.Base.Timestamp,
|
|
SourceID: ft.Base.SourceID,
|
|
},
|
|
DbID: 0,
|
|
CollectionID: collID,
|
|
}
|
|
var status *commonpb.Status
|
|
status, _ = ft.dataService.Flush(ctx, flushReq)
|
|
if status == nil {
|
|
return errors.New("flush resp is nil")
|
|
}
|
|
if status.ErrorCode != commonpb.ErrorCode_Success {
|
|
return errors.New(status.Reason)
|
|
}
|
|
}
|
|
ft.result = &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_Success,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ft *FlushTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type LoadCollectionTask struct {
|
|
Condition
|
|
*milvuspb.LoadCollectionRequest
|
|
ctx context.Context
|
|
queryService types.QueryService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) TraceCtx() context.Context {
|
|
return lct.ctx
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) ID() UniqueID {
|
|
return lct.Base.MsgID
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) SetID(uid UniqueID) {
|
|
lct.Base.MsgID = uid
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) Name() string {
|
|
return LoadCollectionTaskName
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) Type() commonpb.MsgType {
|
|
return lct.Base.MsgType
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) BeginTs() Timestamp {
|
|
return lct.Base.Timestamp
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) EndTs() Timestamp {
|
|
return lct.Base.Timestamp
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) SetTs(ts Timestamp) {
|
|
lct.Base.Timestamp = ts
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) OnEnqueue() error {
|
|
lct.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) PreExecute(ctx context.Context) error {
|
|
log.Debug("LoadCollectionTask PreExecute", zap.String("role", Params.RoleName), zap.Int64("msgID", lct.Base.MsgID))
|
|
lct.Base.MsgType = commonpb.MsgType_LoadCollection
|
|
lct.Base.SourceID = Params.ProxyID
|
|
|
|
collName := lct.CollectionName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) Execute(ctx context.Context) (err error) {
|
|
log.Debug("LoadCollectionTask Execute", zap.String("role", Params.RoleName), zap.Int64("msgID", lct.Base.MsgID))
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, lct.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lct.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
request := &querypb.LoadCollectionRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_LoadCollection,
|
|
MsgID: lct.Base.MsgID,
|
|
Timestamp: lct.Base.Timestamp,
|
|
SourceID: lct.Base.SourceID,
|
|
},
|
|
DbID: 0,
|
|
CollectionID: collID,
|
|
Schema: collSchema,
|
|
}
|
|
log.Debug("send LoadCollectionRequest to query service", zap.String("role", Params.RoleName), zap.Int64("msgID", request.Base.MsgID), zap.Int64("collectionID", request.CollectionID),
|
|
zap.Any("schema", request.Schema))
|
|
lct.result, err = lct.queryService.LoadCollection(ctx, request)
|
|
if err != nil {
|
|
return fmt.Errorf("call query service LoadCollection: %s", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (lct *LoadCollectionTask) PostExecute(ctx context.Context) error {
|
|
log.Debug("LoadCollectionTask PostExecute", zap.String("role", Params.RoleName), zap.Int64("msgID", lct.Base.MsgID))
|
|
return nil
|
|
}
|
|
|
|
type ReleaseCollectionTask struct {
|
|
Condition
|
|
*milvuspb.ReleaseCollectionRequest
|
|
ctx context.Context
|
|
queryService types.QueryService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) TraceCtx() context.Context {
|
|
return rct.ctx
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) ID() UniqueID {
|
|
return rct.Base.MsgID
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) SetID(uid UniqueID) {
|
|
rct.Base.MsgID = uid
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) Name() string {
|
|
return ReleaseCollectionTaskName
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) Type() commonpb.MsgType {
|
|
return rct.Base.MsgType
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) BeginTs() Timestamp {
|
|
return rct.Base.Timestamp
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) EndTs() Timestamp {
|
|
return rct.Base.Timestamp
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) SetTs(ts Timestamp) {
|
|
rct.Base.Timestamp = ts
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) OnEnqueue() error {
|
|
rct.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) PreExecute(ctx context.Context) error {
|
|
rct.Base.MsgType = commonpb.MsgType_ReleaseCollection
|
|
rct.Base.SourceID = Params.ProxyID
|
|
|
|
collName := rct.CollectionName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) Execute(ctx context.Context) (err error) {
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, rct.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
request := &querypb.ReleaseCollectionRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_ReleaseCollection,
|
|
MsgID: rct.Base.MsgID,
|
|
Timestamp: rct.Base.Timestamp,
|
|
SourceID: rct.Base.SourceID,
|
|
},
|
|
DbID: 0,
|
|
CollectionID: collID,
|
|
}
|
|
rct.result, err = rct.queryService.ReleaseCollection(ctx, request)
|
|
return err
|
|
}
|
|
|
|
func (rct *ReleaseCollectionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type LoadPartitionTask struct {
|
|
Condition
|
|
*milvuspb.LoadPartitionsRequest
|
|
ctx context.Context
|
|
queryService types.QueryService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) TraceCtx() context.Context {
|
|
return lpt.ctx
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) ID() UniqueID {
|
|
return lpt.Base.MsgID
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) SetID(uid UniqueID) {
|
|
lpt.Base.MsgID = uid
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) Name() string {
|
|
return LoadPartitionTaskName
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) Type() commonpb.MsgType {
|
|
return lpt.Base.MsgType
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) BeginTs() Timestamp {
|
|
return lpt.Base.Timestamp
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) EndTs() Timestamp {
|
|
return lpt.Base.Timestamp
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) SetTs(ts Timestamp) {
|
|
lpt.Base.Timestamp = ts
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) OnEnqueue() error {
|
|
lpt.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) PreExecute(ctx context.Context) error {
|
|
lpt.Base.MsgType = commonpb.MsgType_LoadPartitions
|
|
lpt.Base.SourceID = Params.ProxyID
|
|
|
|
collName := lpt.CollectionName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) Execute(ctx context.Context) error {
|
|
var partitionIDs []int64
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, lpt.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
collSchema, err := globalMetaCache.GetCollectionSchema(ctx, lpt.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, partitionName := range lpt.PartitionNames {
|
|
partitionID, err := globalMetaCache.GetPartitionID(ctx, lpt.CollectionName, partitionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partitionIDs = append(partitionIDs, partitionID)
|
|
}
|
|
request := &querypb.LoadPartitionsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_LoadPartitions,
|
|
MsgID: lpt.Base.MsgID,
|
|
Timestamp: lpt.Base.Timestamp,
|
|
SourceID: lpt.Base.SourceID,
|
|
},
|
|
DbID: 0,
|
|
CollectionID: collID,
|
|
PartitionIDs: partitionIDs,
|
|
Schema: collSchema,
|
|
}
|
|
lpt.result, err = lpt.queryService.LoadPartitions(ctx, request)
|
|
return err
|
|
}
|
|
|
|
func (lpt *LoadPartitionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
type ReleasePartitionTask struct {
|
|
Condition
|
|
*milvuspb.ReleasePartitionsRequest
|
|
ctx context.Context
|
|
queryService types.QueryService
|
|
result *commonpb.Status
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) TraceCtx() context.Context {
|
|
return rpt.ctx
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) ID() UniqueID {
|
|
return rpt.Base.MsgID
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) SetID(uid UniqueID) {
|
|
rpt.Base.MsgID = uid
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) Type() commonpb.MsgType {
|
|
return rpt.Base.MsgType
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) Name() string {
|
|
return ReleasePartitionTaskName
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) BeginTs() Timestamp {
|
|
return rpt.Base.Timestamp
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) EndTs() Timestamp {
|
|
return rpt.Base.Timestamp
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) SetTs(ts Timestamp) {
|
|
rpt.Base.Timestamp = ts
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) OnEnqueue() error {
|
|
rpt.Base = &commonpb.MsgBase{}
|
|
return nil
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) PreExecute(ctx context.Context) error {
|
|
rpt.Base.MsgType = commonpb.MsgType_ReleasePartitions
|
|
rpt.Base.SourceID = Params.ProxyID
|
|
|
|
collName := rpt.CollectionName
|
|
|
|
if err := ValidateCollectionName(collName); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) Execute(ctx context.Context) (err error) {
|
|
var partitionIDs []int64
|
|
collID, err := globalMetaCache.GetCollectionID(ctx, rpt.CollectionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, partitionName := range rpt.PartitionNames {
|
|
partitionID, err := globalMetaCache.GetPartitionID(ctx, rpt.CollectionName, partitionName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partitionIDs = append(partitionIDs, partitionID)
|
|
}
|
|
request := &querypb.ReleasePartitionsRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_ReleasePartitions,
|
|
MsgID: rpt.Base.MsgID,
|
|
Timestamp: rpt.Base.Timestamp,
|
|
SourceID: rpt.Base.SourceID,
|
|
},
|
|
DbID: 0,
|
|
CollectionID: collID,
|
|
PartitionIDs: partitionIDs,
|
|
}
|
|
rpt.result, err = rpt.queryService.ReleasePartitions(ctx, request)
|
|
return err
|
|
}
|
|
|
|
func (rpt *ReleasePartitionTask) PostExecute(ctx context.Context) error {
|
|
return nil
|
|
}
|