fix: Fix dispatcher deregister and seek (#40860)

1. Fix dispatcher deregistration concurrency (keeping the same logic as the
2.5 branch); see the sketch below.
2. Fix seek when includeCurrentMsg is set. (This is only needed by CDC, so
there is no need to cherry-pick it to the 2.5 branch.)
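
In short, fix 1 stops maintaining numConsumer/numActiveTarget as separately
updated atomics and instead derives the counts under the manager's RWMutex,
so they can never drift from the dispatchers they describe. A minimal
runnable sketch of the pattern, with type names mirroring the diff below and
everything else illustrative:

package main

import (
	"fmt"
	"sync"
)

// Illustrative stand-ins for the real dispatcher types.
type Dispatcher struct{ id int64 }

type dispatcherManager struct {
	mu                sync.RWMutex
	mainDispatcher    *Dispatcher
	deputyDispatchers map[int64]*Dispatcher
}

// NumConsumer derives the consumer count under the read lock, replacing
// the old numConsumer atomic that was updated out of band and could race
// with deregistration.
func (c *dispatcherManager) NumConsumer() int {
	c.mu.RLock()
	defer c.mu.RUnlock()
	n := 0
	if c.mainDispatcher != nil {
		n++
	}
	return n + len(c.deputyDispatchers)
}

func main() {
	c := &dispatcherManager{
		mainDispatcher:    &Dispatcher{id: 1},
		deputyDispatchers: map[int64]*Dispatcher{2: {id: 2}},
	}
	fmt.Println(c.NumConsumer()) // 2
}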

issue: https://github.com/milvus-io/milvus/issues/39862

pr: https://github.com/milvus-io/milvus/pull/39863

---------

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
yihao.dai 2025-03-31 11:38:19 +08:00 (committed by GitHub)
parent 3ecacc4493
commit d8d1dcf076
9 changed files with 148 additions and 91 deletions


@ -110,7 +110,7 @@ func TestClient_Concurrency(t *testing.T) {
// Verify registered targets number.
actual := 0
c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
actual += manager.(*dispatcherManager).registeredTargets.Len()
actual += manager.NumTarget()
return true
})
assert.Equal(t, expected, actual)
@ -120,7 +120,14 @@ func TestClient_Concurrency(t *testing.T) {
actual = 0
c.managers.Range(func(pchannel string, manager DispatcherManager) bool {
m := manager.(*dispatcherManager)
actual += int(m.numActiveTarget.Load())
m.mu.RLock()
defer m.mu.RUnlock()
if m.mainDispatcher != nil {
actual += m.mainDispatcher.targets.Len()
}
for _, d := range m.deputyDispatchers {
actual += d.targets.Len()
}
return true
})
t.Logf("expect = %d, actual = %d\n", expected, actual)
@ -256,9 +263,9 @@ func (suite *SimulationSuite) TestMerge() {
suite.Eventually(func() bool {
for pchannel := range suite.pchannel2Producer {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.T().Logf("dispatcherNum = %d, pchannel = %s\n", manager.(*dispatcherManager).numConsumer.Load(), pchannel)
suite.T().Logf("dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
suite.True(ok)
if manager.(*dispatcherManager).numConsumer.Load() != 1 { // expected all merged, only mainDispatcher exist
if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
return false
}
}
@ -323,9 +330,9 @@ func (suite *SimulationSuite) TestSplit() {
suite.Eventually(func() bool {
for pchannel := range suite.pchannel2Producer {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.T().Logf("verifing dispatchers merged, dispatcherNum = %d, pchannel = %s\n", manager.(*dispatcherManager).numConsumer.Load(), pchannel)
suite.T().Logf("verifing dispatchers merged, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
suite.True(ok)
if manager.(*dispatcherManager).numConsumer.Load() != 1 { // expected all merged, only mainDispatcher exist
if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
return false
}
}
@ -371,8 +378,8 @@ func (suite *SimulationSuite) TestSplit() {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.True(ok)
suite.T().Logf("verifing split, dispatcherNum = %d, splitNum+1 = %d, pchannel = %s\n",
manager.(*dispatcherManager).numConsumer.Load(), splitNumPerPchannel+1, pchannel)
if manager.(*dispatcherManager).numConsumer.Load() < 1 { // expected 1 mainDispatcher and 1 or more split deputyDispatchers
manager.NumConsumer(), splitNumPerPchannel+1, pchannel)
if manager.NumConsumer() < 1 { // expected 1 mainDispatcher and 1 or more split deputyDispatchers
return false
}
}
@ -393,9 +400,9 @@ func (suite *SimulationSuite) TestSplit() {
suite.Eventually(func() bool {
for pchannel := range suite.pchannel2Producer {
manager, ok := suite.client.(*client).managers.Get(pchannel)
suite.T().Logf("verifing dispatchers merged again, dispatcherNum = %d, pchannel = %s\n", manager.(*dispatcherManager).numConsumer.Load(), pchannel)
suite.T().Logf("verifing dispatchers merged again, dispatcherNum = %d, pchannel = %s\n", manager.NumConsumer(), pchannel)
suite.True(ok)
if manager.(*dispatcherManager).numConsumer.Load() != 1 { // expected all merged, only mainDispatcher exist
if manager.NumConsumer() != 1 { // expected all merged, only mainDispatcher exist
return false
}
}


@ -87,6 +87,7 @@ func NewDispatcher(
pchannel string,
position *Pos,
subPos SubPos,
includeCurrentMsg bool,
pullbackEndTs typeutil.Timestamp,
) (*Dispatcher, error) {
subName := fmt.Sprintf("%s-%d-%d", pchannel, id, time.Now().UnixNano())
@ -116,7 +117,7 @@ func NewDispatcher(
return nil, err
}
log.Info("as consumer done", zap.Any("position", position))
err = stream.Seek(ctx, []*Pos{position}, false)
err = stream.Seek(ctx, []*Pos{position}, includeCurrentMsg)
if err != nil {
log.Error("seek failed", zap.Error(err))
return nil, err
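
For context, includeCurrentMsg controls whether the message sitting exactly
at the seek position is redelivered or skipped. A toy model of that
semantics, assuming nothing about the real msgstream implementation:

package main

import "fmt"

// seek returns the suffix of msgs still to be consumed after seeking to
// pos. With includeCurrentMsg, consumption restarts AT pos; otherwise it
// restarts strictly AFTER pos. (Toy model of the flag's meaning only.)
func seek(msgs []int, pos int, includeCurrentMsg bool) []int {
	for i, m := range msgs {
		if m == pos {
			if includeCurrentMsg {
				return msgs[i:]
			}
			return msgs[i+1:]
		}
	}
	return nil
}

func main() {
	msgs := []int{1, 2, 3, 4}
	fmt.Println(seek(msgs, 2, true))  // [2 3 4]
	fmt.Println(seek(msgs, 2, false)) // [3 4]
}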


@ -36,7 +36,7 @@ func TestDispatcher(t *testing.T) {
ctx := context.Background()
t.Run("test base", func(t *testing.T) {
d, err := NewDispatcher(ctx, newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
nil, common.SubscriptionPositionEarliest, false, 0)
assert.NoError(t, err)
assert.NotPanics(t, func() {
d.Handle(start)
@ -65,7 +65,7 @@ func TestDispatcher(t *testing.T) {
},
}
d, err := NewDispatcher(ctx, factory, time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
nil, common.SubscriptionPositionEarliest, false, 0)
assert.Error(t, err)
assert.Nil(t, d)
@ -73,7 +73,7 @@ func TestDispatcher(t *testing.T) {
t.Run("test target", func(t *testing.T) {
d, err := NewDispatcher(ctx, newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
nil, common.SubscriptionPositionEarliest, false, 0)
assert.NoError(t, err)
output := make(chan *msgstream.MsgPack, 1024)
@ -128,7 +128,7 @@ func TestDispatcher(t *testing.T) {
func BenchmarkDispatcher_handle(b *testing.B) {
d, err := NewDispatcher(context.Background(), newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
nil, common.SubscriptionPositionEarliest, false, 0)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
@ -143,7 +143,7 @@ func BenchmarkDispatcher_handle(b *testing.B) {
func TestGroupMessage(t *testing.T) {
d, err := NewDispatcher(context.Background(), newMockFactory(), time.Now().UnixNano(), "mock_pchannel_0",
nil, common.SubscriptionPositionEarliest, 0)
nil, common.SubscriptionPositionEarliest, false, 0)
assert.NoError(t, err)
d.AddTarget(newTarget(&StreamConfig{VChannel: "mock_pchannel_0_1v0"}))
d.AddTarget(newTarget(&StreamConfig{


@ -40,6 +40,8 @@ import (
type DispatcherManager interface {
Add(ctx context.Context, streamConfig *StreamConfig) (<-chan *MsgPack, error)
Remove(vchannel string)
NumTarget() int
NumConsumer() int
Run()
Close()
}
@ -53,9 +55,7 @@ type dispatcherManager struct {
registeredTargets *typeutil.ConcurrentMap[string, *target]
numConsumer atomic.Int64
numActiveTarget atomic.Int64
mu sync.RWMutex
mainDispatcher *Dispatcher
deputyDispatchers map[int64]*Dispatcher // ID -> *Dispatcher
@ -96,9 +96,26 @@ func (c *dispatcherManager) Remove(vchannel string) {
zap.Int64("nodeID", c.nodeID), zap.String("vchannel", vchannel))
return
}
c.removeTargetFromDispatcher(t)
t.close()
}
func (c *dispatcherManager) NumTarget() int {
return c.registeredTargets.Len()
}
func (c *dispatcherManager) NumConsumer() int {
c.mu.RLock()
defer c.mu.RUnlock()
numConsumer := 0
if c.mainDispatcher != nil {
numConsumer++
}
numConsumer += len(c.deputyDispatchers)
return numConsumer
}
func (c *dispatcherManager) Close() {
c.closeOnce.Do(func() {
c.closeChan <- struct{}{}
@ -123,30 +140,46 @@ func (c *dispatcherManager) Run() {
c.tryRemoveUnregisteredTargets()
c.tryBuildDispatcher()
c.tryMerge()
c.updateNumInfo()
}
}
}
func (c *dispatcherManager) updateNumInfo() {
numConsumer := 0
numActiveTarget := 0
func (c *dispatcherManager) removeTargetFromDispatcher(t *target) {
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
c.mu.Lock()
defer c.mu.Unlock()
for _, dispatcher := range c.deputyDispatchers {
if dispatcher.HasTarget(t.vchannel) {
dispatcher.Handle(pause)
dispatcher.RemoveTarget(t.vchannel)
if dispatcher.TargetNum() == 0 {
dispatcher.Handle(terminate)
delete(c.deputyDispatchers, dispatcher.ID())
log.Info("remove deputy dispatcher done", zap.Int64("id", dispatcher.ID()))
} else {
dispatcher.Handle(resume)
}
t.close()
}
}
if c.mainDispatcher != nil {
numConsumer++
numActiveTarget += c.mainDispatcher.TargetNum()
if c.mainDispatcher.HasTarget(t.vchannel) {
c.mainDispatcher.Handle(pause)
c.mainDispatcher.RemoveTarget(t.vchannel)
if c.mainDispatcher.TargetNum() == 0 && len(c.deputyDispatchers) == 0 {
c.mainDispatcher.Handle(terminate)
c.mainDispatcher = nil
} else {
c.mainDispatcher.Handle(resume)
}
t.close()
}
}
numConsumer += len(c.deputyDispatchers)
c.numConsumer.Store(int64(numConsumer))
for _, d := range c.deputyDispatchers {
numActiveTarget += d.TargetNum()
}
c.numActiveTarget.Store(int64(numActiveTarget))
}
func (c *dispatcherManager) tryRemoveUnregisteredTargets() {
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
unregisteredTargets := make([]*target, 0)
c.mu.RLock()
for _, dispatcher := range c.deputyDispatchers {
for _, t := range dispatcher.GetTargets() {
if !c.registeredTargets.Contain(t.vchannel) {
@ -161,36 +194,10 @@ func (c *dispatcherManager) tryRemoveUnregisteredTargets() {
}
}
}
for _, dispatcher := range c.deputyDispatchers {
for _, t := range unregisteredTargets {
if dispatcher.HasTarget(t.vchannel) {
dispatcher.Handle(pause)
dispatcher.RemoveTarget(t.vchannel)
if dispatcher.TargetNum() == 0 {
dispatcher.Handle(terminate)
delete(c.deputyDispatchers, dispatcher.ID())
log.Info("remove deputy dispatcher done", zap.Int64("id", dispatcher.ID()))
} else {
dispatcher.Handle(resume)
}
t.close()
}
}
}
if c.mainDispatcher != nil {
for _, t := range unregisteredTargets {
if c.mainDispatcher.HasTarget(t.vchannel) {
c.mainDispatcher.Handle(pause)
c.mainDispatcher.RemoveTarget(t.vchannel)
if c.mainDispatcher.TargetNum() == 0 && len(c.deputyDispatchers) == 0 {
c.mainDispatcher.Handle(terminate)
c.mainDispatcher = nil
} else {
c.mainDispatcher.Handle(resume)
}
t.close()
}
}
c.mu.RUnlock()
for _, t := range unregisteredTargets {
c.removeTargetFromDispatcher(t)
}
}
@ -202,6 +209,7 @@ func (c *dispatcherManager) tryBuildDispatcher() {
// get lack targets to perform subscription
lackTargets := make([]*target, 0, len(allTargets))
c.mu.RLock()
OUTER:
for _, t := range allTargets {
if c.mainDispatcher != nil && c.mainDispatcher.HasTarget(t.vchannel) {
@ -214,6 +222,7 @@ OUTER:
}
lackTargets = append(lackTargets, t)
}
c.mu.RUnlock()
if len(lackTargets) == 0 {
return
@ -235,6 +244,19 @@ OUTER:
}
}
// For CDC: a dispatcher rebuilt for a lagged target must seek with
// includeCurrentMsg (its last pack timed out and was never delivered),
// while a normally created dispatcher must NOT includeCurrentMsg. So if any
// target lagged, we give up batch subscription and create a dispatcher for
// that single target only.
includeCurrentMsg := false
for _, candidate := range candidateTargets {
if candidate.isLagged {
candidateTargets = []*target{candidate}
includeCurrentMsg = true
candidate.isLagged = false
break
}
}
vchannels := lo.Map(candidateTargets, func(t *target, _ int) string {
return t.vchannel
})
@ -247,7 +269,7 @@ OUTER:
// TODO: add newDispatcher timeout param and init context
id := c.idAllocator.Inc()
d, err := NewDispatcher(context.Background(), c.factory, id, c.pchannel, earliestTarget.pos, earliestTarget.subPos, latestTarget.pos.GetTimestamp())
d, err := NewDispatcher(context.Background(), c.factory, id, c.pchannel, earliestTarget.pos, earliestTarget.subPos, includeCurrentMsg, latestTarget.pos.GetTimestamp())
if err != nil {
panic(err)
}
@ -281,6 +303,21 @@ OUTER:
zap.Strings("vchannels", vchannels),
)
c.mu.Lock()
defer c.mu.Unlock()
d.Handle(pause)
for _, candidate := range candidateTargets {
vchannel := candidate.vchannel
t, ok := c.registeredTargets.Get(vchannel)
// During the build, the target may undergo repeated deregistration and
// registration, causing its channel object to change. Validate that the
// channel is the same as it was before the build; if not, remove the target.
if !ok || t.ch != candidate.ch {
d.RemoveTarget(vchannel)
}
}
d.Handle(resume)
if c.mainDispatcher == nil {
c.mainDispatcher = d
log.Info("add main dispatcher", zap.Int64("id", d.ID()))
@ -291,6 +328,9 @@ OUTER:
}
func (c *dispatcherManager) tryMerge() {
c.mu.Lock()
defer c.mu.Unlock()
start := time.Now()
log := log.With(zap.String("role", c.role), zap.Int64("nodeID", c.nodeID), zap.String("pchannel", c.pchannel))
@ -352,6 +392,9 @@ func (c *dispatcherManager) deleteMetric(channel string) {
}
func (c *dispatcherManager) uploadMetric() {
c.mu.RLock()
defer c.mu.RUnlock()
nodeIDStr := fmt.Sprintf("%d", c.nodeID)
fn := func(gauge *prometheus.GaugeVec) {
if c.mainDispatcher == nil {
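
The channel-identity validation a few hunks above is the crux of the
deregister race: while a dispatcher is being built without the lock held, a
vchannel can be deregistered and re-registered, replacing its channel
object and leaving the new dispatcher with a stale target. A runnable toy
of that check, with hypothetical names standing in for the real types:

package main

import "fmt"

// Hypothetical stand-in for a registered target.
type target struct {
	vchannel string
	ch       chan int
}

// pruneStale keeps a freshly built dispatcher's target only if the
// registered entry still holds the same channel object the build started
// from; a re-registered vchannel gets a new channel and is dropped here,
// to be rebuilt on the next pass.
func pruneStale(built []*target, registered map[string]*target) []*target {
	kept := make([]*target, 0, len(built))
	for _, candidate := range built {
		current, ok := registered[candidate.vchannel]
		if !ok || current.ch != candidate.ch {
			continue // deregistered or replaced mid-build: drop it
		}
		kept = append(kept, candidate)
	}
	return kept
}

func main() {
	a := &target{vchannel: "v1", ch: make(chan int)}
	b := &target{vchannel: "v2", ch: make(chan int)}
	// v2 was deregistered and re-registered with a fresh channel.
	registered := map[string]*target{"v1": a, "v2": {vchannel: "v2", ch: make(chan int)}}
	fmt.Println(len(pruneStale([]*target{a, b}, registered))) // 1
}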


@ -50,8 +50,8 @@ func TestManager(t *testing.T) {
assert.NotNil(t, c)
go c.Run()
defer c.Close()
assert.Equal(t, int64(0), c.(*dispatcherManager).numConsumer.Load())
assert.Equal(t, 0, c.(*dispatcherManager).registeredTargets.Len())
assert.Equal(t, 0, c.NumConsumer())
assert.Equal(t, 0, c.NumTarget())
var offset int
for i := 0; i < 30; i++ {
@ -64,8 +64,8 @@ func TestManager(t *testing.T) {
assert.NoError(t, err)
}
assert.Eventually(t, func() bool {
t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.(*dispatcherManager).numConsumer.Load(), c.(*dispatcherManager).registeredTargets.Len())
return c.(*dispatcherManager).registeredTargets.Len() == offset
t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget())
return c.NumTarget() == offset
}, 3*time.Second, 10*time.Millisecond)
for j := 0; j < rand.Intn(r); j++ {
vchannel := fmt.Sprintf("%s_vchannelv%d", pchannel, offset)
@ -74,8 +74,8 @@ func TestManager(t *testing.T) {
offset--
}
assert.Eventually(t, func() bool {
t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.(*dispatcherManager).numConsumer.Load(), c.(*dispatcherManager).registeredTargets.Len())
return c.(*dispatcherManager).registeredTargets.Len() == offset
t.Logf("offset=%d, numConsumer=%d, numTarget=%d", offset, c.NumConsumer(), c.NumTarget())
return c.NumTarget() == offset
}, 3*time.Second, 10*time.Millisecond)
}
})
@ -108,7 +108,7 @@ func TestManager(t *testing.T) {
assert.NoError(t, err)
o2, err := c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, 3, c.(*dispatcherManager).registeredTargets.Len())
assert.Equal(t, 3, c.NumTarget())
consumeFn := func(output <-chan *MsgPack, done <-chan struct{}, wg *sync.WaitGroup) {
defer wg.Done()
@ -130,14 +130,14 @@ func TestManager(t *testing.T) {
go consumeFn(o2, d2, wg)
assert.Eventually(t, func() bool {
return c.(*dispatcherManager).numConsumer.Load() == 1 // expected merge
return c.NumConsumer() == 1 // expected merge
}, 20*time.Second, 10*time.Millisecond)
// stop consume vchannel_2 to trigger split
d2 <- struct{}{}
assert.Eventually(t, func() bool {
t.Logf("c.NumConsumer=%d", c.(*dispatcherManager).numConsumer.Load())
return c.(*dispatcherManager).numConsumer.Load() == 2 // expected split
t.Logf("c.NumConsumer=%d", c.NumConsumer())
return c.NumConsumer() == 2 // expected split
}, 20*time.Second, 10*time.Millisecond)
// stop all
@ -169,9 +169,9 @@ func TestManager(t *testing.T) {
assert.NoError(t, err)
_, err = c.Add(ctx, NewStreamConfig(fmt.Sprintf("%s_vchannel-2", pchannel), nil, common.SubscriptionPositionUnknown))
assert.NoError(t, err)
assert.Equal(t, 3, c.(*dispatcherManager).registeredTargets.Len())
assert.Equal(t, 3, c.NumTarget())
assert.Eventually(t, func() bool {
return c.(*dispatcherManager).numConsumer.Load() >= 1
return c.NumConsumer() >= 1
}, 3*time.Second, 10*time.Millisecond)
c.(*dispatcherManager).mainDispatcher.curTs.Store(1000)
for _, d := range c.(*dispatcherManager).deputyDispatchers {
@ -183,9 +183,9 @@ func TestManager(t *testing.T) {
defer paramtable.Get().Reset(checkIntervalK)
assert.Eventually(t, func() bool {
return c.(*dispatcherManager).numConsumer.Load() == 1 // expected merged
return c.(*dispatcherManager).NumConsumer() == 1 // expected merged
}, 3*time.Second, 10*time.Millisecond)
assert.Equal(t, 3, c.(*dispatcherManager).registeredTargets.Len())
assert.Equal(t, 3, c.NumTarget())
})
t.Run("test_repeated_vchannel", func(t *testing.T) {
@ -220,7 +220,7 @@ func TestManager(t *testing.T) {
assert.Error(t, err)
assert.Eventually(t, func() bool {
return c.(*dispatcherManager).numConsumer.Load() >= 1
return c.NumConsumer() >= 1
}, 3*time.Second, 10*time.Millisecond)
})
}


@ -34,6 +34,7 @@ type target struct {
ch chan *MsgPack
subPos SubPos
pos *Pos
isLagged bool
closeMu sync.Mutex
closeOnce sync.Once
@ -75,6 +76,7 @@ func (t *target) close() {
t.closed = true
t.timer.Stop()
close(t.ch)
log.Info("close target chan", zap.String("vchannel", t.vchannel))
})
}
@ -97,6 +99,7 @@ func (t *target) send(pack *MsgPack) error {
log.Info("target closed", zap.String("vchannel", t.vchannel))
return nil
case <-t.timer.C:
t.isLagged = true
return fmt.Errorf("send target timeout, vchannel=%s, timeout=%s, beginTs=%d, endTs=%d", t.vchannel, t.maxLag, pack.BeginTs, pack.EndTs)
case t.ch <- pack:
return nil
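
The isLagged flag introduced above is what feeds the single-target rebuild
in manager.go: a timed-out send marks the target, and the next build pass
re-subscribes it alone, seeking with includeCurrentMsg to redeliver the
pack that was never received. A self-contained toy of the marking step (the
real code reuses a timer rather than time.After):

package main

import (
	"errors"
	"fmt"
	"time"
)

// Toy model of the lag marking above.
type target struct {
	ch       chan int
	maxLag   time.Duration
	isLagged bool
}

func (t *target) send(pack int) error {
	select {
	case t.ch <- pack:
		return nil
	case <-time.After(t.maxLag):
		t.isLagged = true // the pack at the current position was dropped
		return errors.New("send target timeout")
	}
}

func main() {
	t := &target{ch: make(chan int), maxLag: 10 * time.Millisecond}
	fmt.Println(t.send(1), t.isLagged) // send target timeout true
}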


@ -474,7 +474,7 @@ func TestSearchGroupByUnsupportedDataType(t *testing.T) {
common.DefaultFloatFieldName, common.DefaultDoubleFieldName,
common.DefaultJSONFieldName, common.DefaultFloatVecFieldName, common.DefaultInt8ArrayField, common.DefaultFloatArrayField,
} {
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName))
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(unsupportedField).WithANNSField(common.DefaultFloatVecFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, false, "unsupported data type")
}
}
@ -495,7 +495,7 @@ func TestSearchGroupByRangeSearch(t *testing.T) {
// range search
_, err := mc.Search(ctx, client.NewSearchOption(collName, common.DefaultLimit, queryVec).WithGroupByField(common.DefaultVarcharFieldName).
WithANNSField(common.DefaultFloatVecFieldName).WithSearchParam("radius", "0").WithSearchParam("range_filter", "0.8"))
WithANNSField(common.DefaultFloatVecFieldName).WithSearchParam("radius", "0").WithSearchParam("range_filter", "0.8").WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, false, "Not allowed to do range-search when doing search-group-by")
}


@ -268,7 +268,7 @@ func TestHybridSearchMultiVectorsPagination(t *testing.T) {
// offset 0, -1 -> 0
for _, offset := range []int{0, -1} {
searchRes, err := mc.HybridSearch(ctx, client.NewHybridSearchOption(schema.CollectionName, common.DefaultLimit, annReqDef).WithOffset(offset))
searchRes, err := mc.HybridSearch(ctx, client.NewHybridSearchOption(schema.CollectionName, common.DefaultLimit, annReqDef).WithOffset(offset).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
common.CheckSearchResult(t, searchRes, common.DefaultNq, common.DefaultLimit)
}


@ -65,14 +65,14 @@ func TestQueryVarcharPkDefault(t *testing.T) {
// query
expr := fmt.Sprintf("%s in ['0', '1', '2', '3', '4']", common.DefaultVarcharFieldName)
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr))
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
common.CheckQueryResult(t, queryRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)})
// get ids -> same result with query
varcharValues := []string{"0", "1", "2", "3", "4"}
ids := column.NewColumnVarChar(common.DefaultVarcharFieldName, varcharValues)
getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids))
getRes, errGet := mc.Get(ctx, client.NewQueryOption(schema.CollectionName).WithIDs(ids).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, errGet, true)
common.CheckQueryResult(t, getRes.Fields, []column.Column{insertRes.IDs.Slice(0, 5)})
}
@ -1094,12 +1094,12 @@ func TestQueryWithTemplateParam(t *testing.T) {
}
// default
queryRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s in {int64Values}", common.DefaultInt64FieldName)).WithTemplateParam("int64Values", int64Values))
WithFilter(fmt.Sprintf("%s in {int64Values}", common.DefaultInt64FieldName)).WithTemplateParam("int64Values", int64Values).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
common.CheckQueryResult(t, queryRes.Fields, []column.Column{column.NewColumnInt64(common.DefaultInt64FieldName, int64Values)})
// cover keys
res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k2}").WithTemplateParam("k2", 10).WithTemplateParam("k2", 5))
res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 < {k2}").WithTemplateParam("k2", 10).WithTemplateParam("k2", 5).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
require.Equal(t, 5, res.ResultCount)
@ -1107,14 +1107,14 @@ func TestQueryWithTemplateParam(t *testing.T) {
anyValues := []int64{0.0, 100.0, 10000.0}
countRes, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("json_contains_any (%s, {any_values})", common.DefaultFloatArrayField)).WithTemplateParam("any_values", anyValues).
WithOutputFields(common.QueryCountFieldName))
WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ := countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 101, count)
// dynamic
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter("dynamicNumber % 2 == {v}").WithTemplateParam("v", 0).WithOutputFields(common.QueryCountFieldName))
WithFilter("dynamicNumber % 2 == {v}").WithTemplateParam("v", 0).WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 1500, count)
@ -1123,7 +1123,8 @@ func TestQueryWithTemplateParam(t *testing.T) {
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s['bool'] == {v}", common.DefaultJSONFieldName)).
WithTemplateParam("v", false).
WithOutputFields(common.QueryCountFieldName))
WithOutputFields(common.QueryCountFieldName).
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, 1500/2, count)
@ -1132,7 +1133,8 @@ func TestQueryWithTemplateParam(t *testing.T) {
countRes, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s == {v}", common.DefaultBoolFieldName)).
WithTemplateParam("v", true).
WithOutputFields(common.QueryCountFieldName))
WithOutputFields(common.QueryCountFieldName).
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ = countRes.Fields[0].GetAsInt64(0)
require.EqualValues(t, common.DefaultNb/2, count)
@ -1141,7 +1143,8 @@ func TestQueryWithTemplateParam(t *testing.T) {
res, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).
WithFilter(fmt.Sprintf("%s >= {k1} && %s < {k2}", common.DefaultInt64FieldName, common.DefaultInt64FieldName)).
WithTemplateParam("v", 0).WithTemplateParam("k1", 1000).
WithTemplateParam("k2", 2000))
WithTemplateParam("k2", 2000).
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
require.EqualValues(t, 1000, res.ResultCount)
}