diff --git a/tests/integration/replicas/load/load_test.go b/tests/integration/replicas/load/load_test.go index bf9fabb284..32ea069bc8 100644 --- a/tests/integration/replicas/load/load_test.go +++ b/tests/integration/replicas/load/load_test.go @@ -30,6 +30,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/rgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" @@ -747,13 +748,23 @@ func (s *LoadTestSuite) TestLoadWithCompact() { // Start a goroutine to continuously insert data and trigger compaction go func() { defer wg.Done() + nextStartPK := int64(1) for { select { case <-stopInsertCh: return default: - s.InsertAndFlush(ctx, dbName, collName, 2000, dim) - _, err := s.Cluster.MilvusClient.ManualCompaction(ctx, &milvuspb.ManualCompactionRequest{ + var err error + nextStartPK, err = s.InsertAndFlush(ctx, dbName, collName, 2000, dim, &integration.PrimaryKeyConfig{ + FieldName: integration.Int64Field, + FieldType: schemapb.DataType_Int64, + NumChannels: 1, + StartPK: nextStartPK, + }) + if err != nil { + return + } + _, err = s.Cluster.MilvusClient.ManualCompaction(ctx, &milvuspb.ManualCompactionRequest{ CollectionName: collName, }) s.NoError(err) diff --git a/tests/integration/util_collection.go b/tests/integration/util_collection.go index 7013df650c..fe979db84a 100644 --- a/tests/integration/util_collection.go +++ b/tests/integration/util_collection.go @@ -29,21 +29,34 @@ type CreateCollectionConfig struct { ResourceGroups []string } -func (s *MiniClusterSuite) InsertAndFlush(ctx context.Context, dbName, collectionName string, rowNum, dim int) error { +type PrimaryKeyConfig struct { + FieldName string + FieldType schemapb.DataType + NumChannels int + StartPK int64 // Starting PK value (default 1 if not specified) +} + +// InsertAndFlush inserts data and flushes. +// Returns the next startPK for subsequent calls and any error. 
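+//
+// A minimal usage sketch (illustrative, mirroring the call sites updated in
+// this patch):
+//
+//	nextPK, err := s.InsertAndFlush(ctx, dbName, collName, 2000, dim, &PrimaryKeyConfig{
+//		FieldName:   Int64Field,
+//		FieldType:   schemapb.DataType_Int64,
+//		NumChannels: 1,
+//		StartPK:     1,
+//	})
+//	// Feed the returned nextPK into StartPK on the next call to keep PKs unique.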
+func (s *MiniClusterSuite) InsertAndFlush(ctx context.Context, dbName, collectionName string, rowNum, dim int, pkConfig *PrimaryKeyConfig) (int64, error) { + startPK := pkConfig.StartPK + if startPK == 0 { + startPK = 1 // Default to 1 if not specified + } + + pkColumn, nextPK := GenerateChannelBalancedPrimaryKeys(pkConfig.FieldName, pkConfig.FieldType, rowNum, pkConfig.NumChannels, startPK) fVecColumn := NewFloatVectorFieldData(FloatVecField, rowNum, dim) - hashKeys := GenerateHashKeys(rowNum) insertResult, err := s.Cluster.MilvusClient.Insert(ctx, &milvuspb.InsertRequest{ DbName: dbName, CollectionName: collectionName, - FieldsData: []*schemapb.FieldData{fVecColumn}, - HashKeys: hashKeys, + FieldsData: []*schemapb.FieldData{pkColumn, fVecColumn}, NumRows: uint32(rowNum), }) if err != nil { - return err + return 0, err } if !merr.Ok(insertResult.Status) { - return merr.Error(insertResult.Status) + return 0, merr.Error(insertResult.Status) } flushResp, err := s.Cluster.MilvusClient.Flush(ctx, &milvuspb.FlushRequest{ @@ -51,11 +64,11 @@ func (s *MiniClusterSuite) InsertAndFlush(ctx context.Context, dbName, collectio CollectionNames: []string{collectionName}, }) if err := merr.CheckRPCCall(flushResp.GetStatus(), err); err != nil { - return err + return 0, err } segmentIDs, has := flushResp.GetCollSegIDs()[collectionName] if !has || segmentIDs == nil { - return merr.Error(&commonpb.Status{ + return 0, merr.Error(&commonpb.Status{ ErrorCode: commonpb.ErrorCode_IllegalArgument, Reason: "failed to get segment IDs", }) @@ -63,17 +76,17 @@ func (s *MiniClusterSuite) InsertAndFlush(ctx context.Context, dbName, collectio ids := segmentIDs.GetData() flushTs, has := flushResp.GetCollFlushTs()[collectionName] if !has { - return merr.Error(&commonpb.Status{ + return 0, merr.Error(&commonpb.Status{ ErrorCode: commonpb.ErrorCode_IllegalArgument, Reason: "failed to get flush timestamp", }) } s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName) - return nil + return nextPK, nil } func (s *MiniClusterSuite) CreateCollectionWithConfiguration(ctx context.Context, cfg *CreateCollectionConfig) { - schema := ConstructSchema(cfg.CollectionName, cfg.Dim, true) + schema := ConstructSchema(cfg.CollectionName, cfg.Dim, false) s.CreateCollection(ctx, cfg, schema) } @@ -107,8 +120,15 @@ func (s *MiniClusterSuite) CreateCollection(ctx context.Context, cfg *CreateColl s.True(merr.Ok(showCollectionsResp.Status)) log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + nextStartPK := int64(1) for i := 0; i < cfg.SegmentNum; i++ { - err = s.InsertAndFlush(ctx, cfg.DBName, cfg.CollectionName, cfg.RowNumPerSegment, cfg.Dim) + var err error + nextStartPK, err = s.InsertAndFlush(ctx, cfg.DBName, cfg.CollectionName, cfg.RowNumPerSegment, cfg.Dim, &PrimaryKeyConfig{ + FieldName: Int64Field, + FieldType: schemapb.DataType_Int64, + NumChannels: cfg.ChannelNum, + StartPK: nextStartPK, + }) s.NoError(err) } diff --git a/tests/integration/util_insert.go b/tests/integration/util_insert.go index 769b27a6b5..2fad0e9111 100644 --- a/tests/integration/util_insert.go +++ b/tests/integration/util_insert.go @@ -18,8 +18,13 @@ package integration import ( "context" + "encoding/binary" + "fmt" + "hash/crc32" "time" + "github.com/spaolacci/murmur3" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/v2/util/testutils" @@ -242,3 +247,193 @@ func GenerateSparseFloatArray(numRows int) *schemapb.SparseFloatArray { func 
GenerateHashKeys(numRows int) []uint32 { return testutils.GenerateHashKeys(numRows) } + +// GenerateChannelBalancedPrimaryKeys generates primary keys that are evenly distributed across channels. +// It supports both Int64 and VarChar primary key types. +// For Int64: uses murmur3 hash (same as typeutil.Hash32Int64) +// For VarChar: uses crc32 hash (same as typeutil.HashString2Uint32) +// startPK specifies where to begin searching for PKs. +// Returns the FieldData and the next startPK for subsequent calls. +func GenerateChannelBalancedPrimaryKeys(fieldName string, fieldType schemapb.DataType, numRows int, numChannels int, startPK int64) (*schemapb.FieldData, int64) { + switch fieldType { + case schemapb.DataType_Int64: + pks, nextPK := GenerateBalancedInt64PKs(numRows, numChannels, startPK) + return &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: pks, + }, + }, + }, + }, + }, nextPK + case schemapb.DataType_VarChar, schemapb.DataType_String: + pks, nextIndex := GenerateBalancedVarCharPKs(numRows, numChannels, int(startPK)) + return &schemapb.FieldData{ + Type: schemapb.DataType_VarChar, + FieldName: fieldName, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: pks, + }, + }, + }, + }, + }, int64(nextIndex) + default: + panic(fmt.Sprintf("not supported primary key type: %s", fieldType)) + } +} + +// GenerateBalancedInt64PKs generates int64 primary keys that are evenly distributed across channels. +// This ensures each channel receives exactly numRows/numChannels items based on PK hash values. +// The function searches for PKs that hash to each channel to achieve exact distribution. +// startPK specifies where to begin searching for PKs. +// Returns the generated PKs and the next startPK for subsequent calls. +func GenerateBalancedInt64PKs(numRows int, numChannels int, startPK int64) ([]int64, int64) { + if numChannels <= 0 { + numChannels = 1 + } + + // Calculate how many items each channel should receive + baseCount := numRows / numChannels + remainder := numRows % numChannels + + // Collect PKs for each channel + channelPKs := make([][]int64, numChannels) + targetCounts := make([]int, numChannels) + for ch := 0; ch < numChannels; ch++ { + targetCounts[ch] = baseCount + if ch < remainder { + targetCounts[ch]++ + } + channelPKs[ch] = make([]int64, 0, targetCounts[ch]) + } + + // Search for PKs that hash to each channel + var lastPK int64 + for pk := startPK; ; pk++ { + lastPK = pk + // Calculate which channel this PK would go to + hash := hashInt64ForChannel(pk) + ch := int(hash % uint32(numChannels)) + + if len(channelPKs[ch]) < targetCounts[ch] { + channelPKs[ch] = append(channelPKs[ch], pk) + + // Check if all channels have enough PKs + done := true + for ch := 0; ch < numChannels; ch++ { + if len(channelPKs[ch]) < targetCounts[ch] { + done = false + break + } + } + if done { + break + } + } + } + + // Combine all PKs + result := make([]int64, 0, numRows) + for ch := 0; ch < numChannels; ch++ { + result = append(result, channelPKs[ch]...) + } + + return result, lastPK + 1 +} + +// hashInt64ForChannel computes the hash value for channel assignment. +// This mirrors the logic in typeutil.Hash32Int64 and HashPK2Channels. 
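+//
+// A sketch of the channel-assignment step this supports (assuming numChannels
+// shards, as in GenerateBalancedInt64PKs above):
+//
+//	hash := hashInt64ForChannel(pk)
+//	channel := int(hash % uint32(numChannels))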
+func hashInt64ForChannel(v int64) uint32 { + // Must match the behavior of typeutil.Hash32Int64 + // which uses common.Endian (binary.LittleEndian) + b := make([]byte, 8) + binary.LittleEndian.PutUint64(b, uint64(v)) + + // Use murmur3 hash (same as typeutil.Hash32Bytes) + h := murmur3.New32() + h.Write(b) + return h.Sum32() & 0x7fffffff +} + +// GenerateBalancedVarCharPKs generates varchar primary keys that are evenly distributed across channels. +// This ensures each channel receives exactly numRows/numChannels items based on PK hash values. +// The function searches for PKs that hash to each channel to achieve exact distribution. +// startIndex specifies where to begin searching for PKs (used in "pk_" format). +// Returns the generated PKs and the next startIndex for subsequent calls. +func GenerateBalancedVarCharPKs(numRows int, numChannels int, startIndex int) ([]string, int) { + if numChannels <= 0 { + numChannels = 1 + } + + // Calculate how many items each channel should receive + baseCount := numRows / numChannels + remainder := numRows % numChannels + + // Collect PKs for each channel + channelPKs := make([][]string, numChannels) + targetCounts := make([]int, numChannels) + for ch := 0; ch < numChannels; ch++ { + targetCounts[ch] = baseCount + if ch < remainder { + targetCounts[ch]++ + } + channelPKs[ch] = make([]string, 0, targetCounts[ch]) + } + + // Search for PKs that hash to each channel + var lastIndex int + for i := startIndex; ; i++ { + lastIndex = i + // Generate a unique string PK + pk := fmt.Sprintf("pk_%d", i) + + // Calculate which channel this PK would go to + hash := hashVarCharForChannel(pk) + ch := int(hash % uint32(numChannels)) + + if len(channelPKs[ch]) < targetCounts[ch] { + channelPKs[ch] = append(channelPKs[ch], pk) + + // Check if all channels have enough PKs + done := true + for ch := 0; ch < numChannels; ch++ { + if len(channelPKs[ch]) < targetCounts[ch] { + done = false + break + } + } + if done { + break + } + } + } + + // Combine all PKs + result := make([]string, 0, numRows) + for ch := 0; ch < numChannels; ch++ { + result = append(result, channelPKs[ch]...) + } + + return result, lastIndex + 1 +} + +// hashVarCharForChannel computes the hash value for channel assignment of varchar PKs. +// This mirrors the logic in typeutil.HashString2Uint32 and HashPK2Channels. +func hashVarCharForChannel(v string) uint32 { + // Must match the behavior of typeutil.HashString2Uint32 + // which uses crc32.ChecksumIEEE with substring limit of 100 chars + subString := v + if len(v) > 100 { + subString = v[:100] + } + return crc32.ChecksumIEEE([]byte(subString)) +} diff --git a/tests/integration/util_insert_test.go b/tests/integration/util_insert_test.go new file mode 100644 index 0000000000..6fbf80f8c0 --- /dev/null +++ b/tests/integration/util_insert_test.go @@ -0,0 +1,730 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" +) + +func TestGenerateBalancedInt64PKs(t *testing.T) { + t.Run("basic_functionality", func(t *testing.T) { + numRows := 100 + numChannels := 4 + pks, nextPK := GenerateBalancedInt64PKs(numRows, numChannels, 1) + + assert.Equal(t, numRows, len(pks), "should generate correct number of PKs") + assert.Greater(t, nextPK, int64(numRows), "nextPK should be greater than numRows") + }) + + t.Run("zero_channels_defaults_to_one", func(t *testing.T) { + numRows := 10 + pks, _ := GenerateBalancedInt64PKs(numRows, 0, 1) + + assert.Equal(t, numRows, len(pks), "should generate correct number of PKs") + }) + + t.Run("negative_channels_defaults_to_one", func(t *testing.T) { + numRows := 10 + pks, _ := GenerateBalancedInt64PKs(numRows, -5, 1) + + assert.Equal(t, numRows, len(pks), "should generate correct number of PKs") + }) + + t.Run("balanced_distribution_by_hash", func(t *testing.T) { + numRows := 100 + numChannels := 4 + pks, _ := GenerateBalancedInt64PKs(numRows, numChannels, 1) + + // Verify distribution by hashing PKs + channelCounts := make(map[int]int) + for _, pk := range pks { + hash := hashInt64ForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + // Each channel should have 25 PKs (100/4 = 25) + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[ch], + "channel %d should have %d PKs", ch, expectedCount) + } + }) + + t.Run("remainder_distribution", func(t *testing.T) { + numRows := 10 + numChannels := 3 + pks, _ := GenerateBalancedInt64PKs(numRows, numChannels, 1) + + channelCounts := make(map[int]int) + for _, pk := range pks { + hash := hashInt64ForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + // 10 / 3 = 3 base, remainder = 1 + // Channel 0: 4 PKs (3 + 1 from remainder) + // Channel 1: 3 PKs + // Channel 2: 3 PKs + assert.Equal(t, 4, channelCounts[0], "channel 0 should have 4 PKs") + assert.Equal(t, 3, channelCounts[1], "channel 1 should have 3 PKs") + assert.Equal(t, 3, channelCounts[2], "channel 2 should have 3 PKs") + }) + + t.Run("unique_pks", func(t *testing.T) { + numRows := 100 + numChannels := 4 + pks, _ := GenerateBalancedInt64PKs(numRows, numChannels, 1) + + // Verify all PKs are unique + seen := make(map[int64]bool) + for _, pk := range pks { + assert.False(t, seen[pk], "PK %d should be unique", pk) + seen[pk] = true + } + }) + + t.Run("positive_pks", func(t *testing.T) { + numRows := 50 + numChannels := 5 + pks, _ := GenerateBalancedInt64PKs(numRows, numChannels, 1) + + for _, pk := range pks { + assert.Greater(t, pk, int64(0), "PKs should be positive") + } + }) + + t.Run("continuation_no_duplicates", func(t *testing.T) { + numRows := 100 + numChannels := 4 + + // First call + pks1, nextPK := GenerateBalancedInt64PKs(numRows, numChannels, 1) + + // Second call continues from nextPK + pks2, _ := GenerateBalancedInt64PKs(numRows, numChannels, nextPK) + + // Verify no overlap between pks1 and pks2 + seen := make(map[int64]bool) + for _, pk := range pks1 { + seen[pk] = true + } + for _, pk := range pks2 { + assert.False(t, seen[pk], "duplicate PK found: %d", pk) + } + }) + + t.Run("custom_start_pk", func(t *testing.T) { + numRows := 10 
+ numChannels := 2 + startPK := int64(1000) + + pks, nextPK := GenerateBalancedInt64PKs(numRows, numChannels, startPK) + + // All PKs should be >= startPK + for _, pk := range pks { + assert.GreaterOrEqual(t, pk, startPK, "PK should be >= startPK") + } + assert.Greater(t, nextPK, startPK, "nextPK should be > startPK") + }) +} + +func TestHashInt64ForChannel(t *testing.T) { + t.Run("consistency", func(t *testing.T) { + // Same input should always produce same output + pk := int64(12345) + hash1 := hashInt64ForChannel(pk) + hash2 := hashInt64ForChannel(pk) + + assert.Equal(t, hash1, hash2, "same input should produce same hash") + }) + + t.Run("different_inputs_different_hashes", func(t *testing.T) { + // Different inputs should generally produce different hashes + // (with very high probability) + hashes := make(map[uint32]int64) + collisions := 0 + + for pk := int64(1); pk <= 1000; pk++ { + hash := hashInt64ForChannel(pk) + if existingPK, exists := hashes[hash]; exists { + collisions++ + t.Logf("collision: PK %d and %d both hash to %d", pk, existingPK, hash) + } + hashes[hash] = pk + } + + // Allow a small number of collisions (hash collisions are possible) + assert.Less(t, collisions, 10, + "too many hash collisions for first 1000 PKs") + }) + + t.Run("non_negative_result", func(t *testing.T) { + // The hash should always be non-negative (due to & 0x7fffffff) + testCases := []int64{0, 1, -1, 100, -100, 1 << 62, -(1 << 62)} + + for _, pk := range testCases { + hash := hashInt64ForChannel(pk) + assert.GreaterOrEqual(t, hash, uint32(0), + "hash for PK %d should be non-negative", pk) + } + }) + + t.Run("distribution_across_channels", func(t *testing.T) { + // Test that hashes distribute well across channels + numChannels := 8 + channelCounts := make(map[int]int) + + for pk := int64(1); pk <= 8000; pk++ { + hash := hashInt64ForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + // Each channel should have roughly 1000 items (8000/8) + // Allow 20% variance + expectedCount := 1000 + tolerance := 200 + + for ch := 0; ch < numChannels; ch++ { + count := channelCounts[ch] + assert.Greater(t, count, expectedCount-tolerance, + "channel %d has too few items: %d", ch, count) + assert.Less(t, count, expectedCount+tolerance, + "channel %d has too many items: %d", ch, count) + } + }) +} + +func TestGenerateChannelBalancedPrimaryKeys(t *testing.T) { + t.Run("int64_type", func(t *testing.T) { + numRows := 100 + numChannels := 4 + fieldName := "test_pk" + + fieldData, nextPK := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_Int64, numRows, numChannels, 1) + + assert.Equal(t, schemapb.DataType_Int64, fieldData.GetType()) + assert.Equal(t, fieldName, fieldData.GetFieldName()) + assert.Greater(t, nextPK, int64(0), "nextPK should be positive") + + pks := fieldData.GetScalars().GetLongData().GetData() + assert.Equal(t, numRows, len(pks)) + + // Verify balanced distribution + channelCounts := make(map[int]int) + for _, pk := range pks { + hash := hashInt64ForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[ch], + "channel %d should have %d PKs", ch, expectedCount) + } + }) + + t.Run("varchar_type", func(t *testing.T) { + numRows := 100 + numChannels := 4 + fieldName := "test_varchar_pk" + + fieldData, nextPK := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_VarChar, numRows, numChannels, 1) + + 
assert.Equal(t, schemapb.DataType_VarChar, fieldData.GetType()) + assert.Equal(t, fieldName, fieldData.GetFieldName()) + assert.Greater(t, nextPK, int64(0), "nextPK should be positive") + + pks := fieldData.GetScalars().GetStringData().GetData() + assert.Equal(t, numRows, len(pks)) + + // Verify balanced distribution + channelCounts := make(map[int]int) + for _, pk := range pks { + hash := hashVarCharForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[ch], + "channel %d should have %d PKs", ch, expectedCount) + } + }) + + t.Run("string_type_as_varchar", func(t *testing.T) { + numRows := 50 + numChannels := 2 + fieldName := "string_pk" + + fieldData, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_String, numRows, numChannels, 1) + + // String type should be treated as VarChar + assert.Equal(t, schemapb.DataType_VarChar, fieldData.GetType()) + assert.Equal(t, fieldName, fieldData.GetFieldName()) + + pks := fieldData.GetScalars().GetStringData().GetData() + assert.Equal(t, numRows, len(pks)) + }) + + t.Run("unsupported_type_panics", func(t *testing.T) { + assert.Panics(t, func() { + GenerateChannelBalancedPrimaryKeys("test", schemapb.DataType_Float, 10, 2, 1) + }, "unsupported type should panic") + }) + + t.Run("continuation_no_duplicates", func(t *testing.T) { + numRows := 100 + numChannels := 4 + fieldName := "test_pk" + + // First call + fieldData1, nextPK := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_Int64, numRows, numChannels, 1) + pks1 := fieldData1.GetScalars().GetLongData().GetData() + + // Second call continues from nextPK + fieldData2, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_Int64, numRows, numChannels, nextPK) + pks2 := fieldData2.GetScalars().GetLongData().GetData() + + // Verify no overlap + seen := make(map[int64]bool) + for _, pk := range pks1 { + seen[pk] = true + } + for _, pk := range pks2 { + assert.False(t, seen[pk], "duplicate PK found: %d", pk) + } + }) +} + +func TestGenerateBalancedVarCharPKs(t *testing.T) { + t.Run("basic_functionality", func(t *testing.T) { + numRows := 100 + numChannels := 4 + pks, nextIndex := GenerateBalancedVarCharPKs(numRows, numChannels, 1) + + assert.Equal(t, numRows, len(pks), "should generate correct number of PKs") + assert.Greater(t, nextIndex, numRows, "nextIndex should be greater than numRows") + }) + + t.Run("zero_channels_defaults_to_one", func(t *testing.T) { + numRows := 10 + pks, _ := GenerateBalancedVarCharPKs(numRows, 0, 1) + + assert.Equal(t, numRows, len(pks), "should generate correct number of PKs") + }) + + t.Run("negative_channels_defaults_to_one", func(t *testing.T) { + numRows := 10 + pks, _ := GenerateBalancedVarCharPKs(numRows, -5, 1) + + assert.Equal(t, numRows, len(pks), "should generate correct number of PKs") + }) + + t.Run("balanced_distribution_by_hash", func(t *testing.T) { + numRows := 100 + numChannels := 4 + pks, _ := GenerateBalancedVarCharPKs(numRows, numChannels, 1) + + // Verify distribution by hashing PKs + channelCounts := make(map[int]int) + for _, pk := range pks { + hash := hashVarCharForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + // Each channel should have 25 PKs (100/4 = 25) + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[ch], + "channel %d should have %d 
PKs", ch, expectedCount) + } + }) + + t.Run("remainder_distribution", func(t *testing.T) { + numRows := 10 + numChannels := 3 + pks, _ := GenerateBalancedVarCharPKs(numRows, numChannels, 1) + + channelCounts := make(map[int]int) + for _, pk := range pks { + hash := hashVarCharForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + // 10 / 3 = 3 base, remainder = 1 + // Channel 0: 4 PKs (3 + 1 from remainder) + // Channel 1: 3 PKs + // Channel 2: 3 PKs + assert.Equal(t, 4, channelCounts[0], "channel 0 should have 4 PKs") + assert.Equal(t, 3, channelCounts[1], "channel 1 should have 3 PKs") + assert.Equal(t, 3, channelCounts[2], "channel 2 should have 3 PKs") + }) + + t.Run("unique_pks", func(t *testing.T) { + numRows := 100 + numChannels := 4 + pks, _ := GenerateBalancedVarCharPKs(numRows, numChannels, 1) + + // Verify all PKs are unique + seen := make(map[string]bool) + for _, pk := range pks { + assert.False(t, seen[pk], "PK %s should be unique", pk) + seen[pk] = true + } + }) + + t.Run("non_empty_pks", func(t *testing.T) { + numRows := 50 + numChannels := 5 + pks, _ := GenerateBalancedVarCharPKs(numRows, numChannels, 1) + + for _, pk := range pks { + assert.NotEmpty(t, pk, "PKs should not be empty") + } + }) + + t.Run("continuation_no_duplicates", func(t *testing.T) { + numRows := 100 + numChannels := 4 + + // First call + pks1, nextIndex := GenerateBalancedVarCharPKs(numRows, numChannels, 1) + + // Second call continues from nextIndex + pks2, _ := GenerateBalancedVarCharPKs(numRows, numChannels, nextIndex) + + // Verify no overlap between pks1 and pks2 + seen := make(map[string]bool) + for _, pk := range pks1 { + seen[pk] = true + } + for _, pk := range pks2 { + assert.False(t, seen[pk], "duplicate PK found: %s", pk) + } + }) +} + +func TestHashVarCharForChannel(t *testing.T) { + t.Run("consistency", func(t *testing.T) { + // Same input should always produce same output + pk := "test_pk_12345" + hash1 := hashVarCharForChannel(pk) + hash2 := hashVarCharForChannel(pk) + + assert.Equal(t, hash1, hash2, "same input should produce same hash") + }) + + t.Run("different_inputs_different_hashes", func(t *testing.T) { + // Different inputs should generally produce different hashes + hashes := make(map[uint32]string) + collisions := 0 + + for i := 1; i <= 1000; i++ { + // Use unique pk format: pk_ + pk := fmt.Sprintf("pk_%d", i) + hash := hashVarCharForChannel(pk) + if existingPK, exists := hashes[hash]; exists { + collisions++ + t.Logf("collision: PK %s and %s both hash to %d", pk, existingPK, hash) + } + hashes[hash] = pk + } + + // Allow some collisions (hash collisions are expected) + assert.Less(t, collisions, 50, + "too many hash collisions for first 1000 PKs") + }) + + t.Run("substring_limit", func(t *testing.T) { + // Strings longer than 100 chars should only hash first 100 chars + base := "a" + longStr := "" + for i := 0; i < 150; i++ { + longStr += base + } + shortStr := longStr[:100] + + // Hash of long string should equal hash of first 100 chars + hashLong := hashVarCharForChannel(longStr) + hashShort := hashVarCharForChannel(shortStr) + + assert.Equal(t, hashShort, hashLong, + "hash of long string should equal hash of first 100 chars") + }) + + t.Run("distribution_across_channels", func(t *testing.T) { + // Test that hashes distribute well across channels + numChannels := 8 + channelCounts := make(map[int]int) + + for i := 1; i <= 8000; i++ { + // Use unique pk format for distribution test + pk := fmt.Sprintf("distribution_test_pk_%d", i) + hash := 
hashVarCharForChannel(pk) + ch := int(hash % uint32(numChannels)) + channelCounts[ch]++ + } + + // Each channel should have roughly 1000 items (8000/8) + // Allow 20% variance + expectedCount := 1000 + tolerance := 200 + + for ch := 0; ch < numChannels; ch++ { + count := channelCounts[ch] + assert.Greater(t, count, expectedCount-tolerance, + "channel %d has too few items: %d", ch, count) + assert.Less(t, count, expectedCount+tolerance, + "channel %d has too many items: %d", ch, count) + } + }) +} + +// TestHashPK2ChannelsIntegration verifies that GenerateChannelBalancedPrimaryKeys +// produces PKs that are evenly distributed when using the actual HashPK2Channels function. +// This is an end-to-end test to ensure our hash implementation matches Milvus's internal implementation. +func TestHashPK2ChannelsIntegration(t *testing.T) { + t.Run("int64_pk_balanced_with_HashPK2Channels", func(t *testing.T) { + numRows := 100 + numChannels := 4 + fieldName := "test_pk" + + // Generate balanced PKs + fieldData, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_Int64, numRows, numChannels, 1) + pks := fieldData.GetScalars().GetLongData().GetData() + + // Create schemapb.IDs for HashPK2Channels + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: pks, + }, + }, + } + + // Create shard names + shardNames := make([]string, numChannels) + for i := 0; i < numChannels; i++ { + shardNames[i] = fmt.Sprintf("shard_%d", i) + } + + // Use actual HashPK2Channels to get channel assignments + channelIndices := typeutil.HashPK2Channels(ids, shardNames) + + // Count distribution + channelCounts := make(map[uint32]int) + for _, ch := range channelIndices { + channelCounts[ch]++ + } + + // Verify balanced distribution: each channel should have exactly numRows/numChannels + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[uint32(ch)], + "channel %d should have exactly %d PKs via HashPK2Channels", ch, expectedCount) + } + }) + + t.Run("int64_pk_with_remainder", func(t *testing.T) { + numRows := 10 + numChannels := 3 + fieldName := "test_pk" + + fieldData, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_Int64, numRows, numChannels, 1) + pks := fieldData.GetScalars().GetLongData().GetData() + + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: pks, + }, + }, + } + + shardNames := make([]string, numChannels) + for i := 0; i < numChannels; i++ { + shardNames[i] = fmt.Sprintf("shard_%d", i) + } + + channelIndices := typeutil.HashPK2Channels(ids, shardNames) + + channelCounts := make(map[uint32]int) + for _, ch := range channelIndices { + channelCounts[ch]++ + } + + // 10 / 3 = 3 base, remainder = 1 + // Channel 0: 4, Channel 1: 3, Channel 2: 3 + assert.Equal(t, 4, channelCounts[0], "channel 0 should have 4 PKs") + assert.Equal(t, 3, channelCounts[1], "channel 1 should have 3 PKs") + assert.Equal(t, 3, channelCounts[2], "channel 2 should have 3 PKs") + }) + + t.Run("varchar_pk_balanced_with_HashPK2Channels", func(t *testing.T) { + numRows := 100 + numChannels := 4 + fieldName := "test_varchar_pk" + + fieldData, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_VarChar, numRows, numChannels, 1) + pks := fieldData.GetScalars().GetStringData().GetData() + + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: pks, + }, + }, + } + + shardNames := make([]string, numChannels) + 
for i := 0; i < numChannels; i++ { + shardNames[i] = fmt.Sprintf("shard_%d", i) + } + + channelIndices := typeutil.HashPK2Channels(ids, shardNames) + + channelCounts := make(map[uint32]int) + for _, ch := range channelIndices { + channelCounts[ch]++ + } + + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[uint32(ch)], + "channel %d should have exactly %d PKs via HashPK2Channels", ch, expectedCount) + } + }) + + t.Run("varchar_pk_with_remainder", func(t *testing.T) { + numRows := 10 + numChannels := 3 + fieldName := "test_varchar_pk" + + fieldData, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_VarChar, numRows, numChannels, 1) + pks := fieldData.GetScalars().GetStringData().GetData() + + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: pks, + }, + }, + } + + shardNames := make([]string, numChannels) + for i := 0; i < numChannels; i++ { + shardNames[i] = fmt.Sprintf("shard_%d", i) + } + + channelIndices := typeutil.HashPK2Channels(ids, shardNames) + + channelCounts := make(map[uint32]int) + for _, ch := range channelIndices { + channelCounts[ch]++ + } + + // 10 / 3 = 3 base, remainder = 1 + assert.Equal(t, 4, channelCounts[0], "channel 0 should have 4 PKs") + assert.Equal(t, 3, channelCounts[1], "channel 1 should have 3 PKs") + assert.Equal(t, 3, channelCounts[2], "channel 2 should have 3 PKs") + }) + + t.Run("large_scale_int64_distribution", func(t *testing.T) { + numRows := 1000 + numChannels := 8 + fieldName := "test_pk" + + fieldData, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_Int64, numRows, numChannels, 1) + pks := fieldData.GetScalars().GetLongData().GetData() + + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: pks, + }, + }, + } + + shardNames := make([]string, numChannels) + for i := 0; i < numChannels; i++ { + shardNames[i] = fmt.Sprintf("shard_%d", i) + } + + channelIndices := typeutil.HashPK2Channels(ids, shardNames) + + channelCounts := make(map[uint32]int) + for _, ch := range channelIndices { + channelCounts[ch]++ + } + + // Each channel should have exactly 125 PKs (1000/8) + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[uint32(ch)], + "channel %d should have exactly %d PKs via HashPK2Channels", ch, expectedCount) + } + }) + + t.Run("large_scale_varchar_distribution", func(t *testing.T) { + numRows := 1000 + numChannels := 8 + fieldName := "test_varchar_pk" + + fieldData, _ := GenerateChannelBalancedPrimaryKeys(fieldName, schemapb.DataType_VarChar, numRows, numChannels, 1) + pks := fieldData.GetScalars().GetStringData().GetData() + + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: pks, + }, + }, + } + + shardNames := make([]string, numChannels) + for i := 0; i < numChannels; i++ { + shardNames[i] = fmt.Sprintf("shard_%d", i) + } + + channelIndices := typeutil.HashPK2Channels(ids, shardNames) + + channelCounts := make(map[uint32]int) + for _, ch := range channelIndices { + channelCounts[ch]++ + } + + // Each channel should have exactly 125 PKs (1000/8) + expectedCount := numRows / numChannels + for ch := 0; ch < numChannels; ch++ { + assert.Equal(t, expectedCount, channelCounts[uint32(ch)], + "channel %d should have exactly %d PKs via HashPK2Channels", ch, expectedCount) + } + }) +}