From 4781db8a2a07739f7c08f4b3cb5452d2083decbd Mon Sep 17 00:00:00 2001 From: godchen Date: Tue, 12 Apr 2022 22:19:34 +0800 Subject: [PATCH] Add datanode import (#16414) Signed-off-by: godchen0212 --- internal/datacoord/services.go | 3 +- internal/datanode/data_node.go | 216 ++++++++++++++++++ internal/datanode/data_node_test.go | 35 ++- internal/datanode/mock_test.go | 19 ++ internal/storage/local_chunk_manager.go | 14 +- internal/storage/local_chunk_manager_test.go | 12 +- internal/storage/minio_chunk_manager.go | 14 +- internal/storage/minio_chunk_manager_test.go | 12 +- internal/storage/types.go | 17 +- internal/storage/vector_chunk_manager.go | 19 +- internal/storage/vector_chunk_manager_test.go | 11 +- internal/util/importutil/import_wrapper.go | 60 ++--- .../util/importutil/import_wrapper_test.go | 150 ++++++------ internal/util/importutil/json_handler.go | 186 +++++++-------- internal/util/importutil/json_handler_test.go | 25 +- internal/util/importutil/json_parser.go | 40 ++-- internal/util/importutil/numpy_adapter.go | 11 + internal/util/importutil/numpy_parser_test.go | 5 +- 18 files changed, 600 insertions(+), 249 deletions(-) diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 497c4d47da..6cc9a2126b 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -24,6 +24,8 @@ import ( "sync/atomic" "time" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" @@ -33,7 +35,6 @@ import ( "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/trace" - "go.uber.org/zap" ) const moduleName = "DataCoord" diff --git a/internal/datanode/data_node.go b/internal/datanode/data_node.go index 3d77667047..b0db884cb4 100644 --- a/internal/datanode/data_node.go +++ b/internal/datanode/data_node.go @@ -39,6 +39,7 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" + allocator2 "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" @@ -46,12 +47,15 @@ import ( "github.com/milvus-io/milvus/internal/metrics" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/importutil" "github.com/milvus-io/milvus/internal/util/logutil" "github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/paramtable" @@ -787,9 +791,221 @@ func (node *DataNode) Import(ctx context.Context, req *datapb.ImportTaskRequest) Reason: msgDataNodeIsUnhealthy(Params.DataNodeCfg.NodeID), }, nil } + rep, err := node.rootCoord.AllocTimestamp(node.ctx, &rootcoordpb.AllocTimestampRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_RequestTSO, + MsgID: 0, + Timestamp: 0, + SourceID: node.NodeID, + }, + Count: 1, + }) + if rep.Status.ErrorCode != commonpb.ErrorCode_Success || 
err != nil { + if err != nil { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "DataNode alloc ts failed", + }, nil + } + } + + ts := rep.GetTimestamp() + + metaService := newMetaService(node.rootCoord, req.GetImportTask().GetCollectionId()) + schema, err := metaService.getCollectionSchema(ctx, req.GetImportTask().GetCollectionId(), 0) + if err != nil { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, nil + } + idAllocator, err := allocator2.NewIDAllocator(node.ctx, node.rootCoord, Params.DataNodeCfg.NodeID) + importWrapper := importutil.NewImportWrapper(ctx, schema, 2, Params.DataNodeCfg.FlushInsertBufferSize/(1024*1024), idAllocator, node.chunkManager, importFlushReqFunc(node, req, schema, ts)) + err = importWrapper.Import(req.GetImportTask().GetFiles(), req.GetImportTask().GetRowBased(), false) + if err != nil { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, nil + } resp := &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, } return resp, nil } + +type importFlushFunc func(fields map[storage.FieldID]storage.FieldData) error + +func importFlushReqFunc(node *DataNode, req *datapb.ImportTaskRequest, schema *schemapb.CollectionSchema, ts Timestamp) importFlushFunc { + return func(fields map[storage.FieldID]storage.FieldData) error { + segReqs := []*datapb.SegmentIDRequest{ + { + ChannelName: "test-channel", + Count: 1, + CollectionID: req.GetImportTask().GetCollectionId(), + PartitionID: req.GetImportTask().GetCollectionId(), + }, + } + segmentIDReq := &datapb.AssignSegmentIDRequest{ + NodeID: 0, + PeerRole: typeutil.ProxyRole, + SegmentIDRequests: segReqs, + } + + resp, err := node.dataCoord.AssignSegmentID(context.Background(), segmentIDReq) + if err != nil { + return fmt.Errorf("syncSegmentID Failed:%w", err) + } + + if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + return fmt.Errorf("syncSegmentID Failed:%s", resp.Status.Reason) + } + segmentID := resp.SegIDAssignments[0].SegID + + var rowNum int + for _, field := range fields { + rowNum = field.RowNum() + break + } + tsFieldData := make([]int64, rowNum) + for i := range tsFieldData { + tsFieldData[i] = int64(ts) + } + fields[common.TimeStampField] = &storage.Int64FieldData{ + Data: tsFieldData, + NumRows: []int64{int64(rowNum)}, + } + var pkFieldID int64 + for _, field := range schema.Fields { + if field.IsPrimaryKey { + pkFieldID = field.GetFieldID() + break + } + } + fields[common.RowIDField] = fields[pkFieldID] + + data := BufferData{buffer: &InsertData{ + Data: fields, + }} + meta := &etcdpb.CollectionMeta{ + ID: req.GetImportTask().GetCollectionId(), + Schema: schema, + } + inCodec := storage.NewInsertCodec(meta) + + binLogs, statsBinlogs, err := inCodec.Serialize(req.GetImportTask().GetPartitionId(), segmentID, data.buffer) + if err != nil { + return err + } + + var alloc allocatorInterface = newAllocator(node.rootCoord) + start, _, err := alloc.allocIDBatch(uint32(len(binLogs))) + if err != nil { + return err + } + + field2Insert := make(map[UniqueID]*datapb.Binlog, len(binLogs)) + kvs := make(map[string][]byte, len(binLogs)) + field2Logidx := make(map[UniqueID]UniqueID, len(binLogs)) + for idx, blob := range binLogs { + fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64) + if err != nil { + log.Error("Flush failed ... 
cannot parse string to fieldID ..", zap.Error(err)) + return err + } + + logidx := start + int64(idx) + + // no error raise if alloc=false + k := JoinIDPath(req.GetImportTask().GetCollectionId(), req.GetImportTask().GetPartitionId(), segmentID, fieldID, logidx) + + key := path.Join(Params.DataNodeCfg.InsertBinlogRootPath, k) + kvs[key] = blob.Value[:] + field2Insert[fieldID] = &datapb.Binlog{ + EntriesNum: data.size, + TimestampFrom: 0, //TODO + TimestampTo: 0, //TODO, + LogPath: key, + LogSize: int64(len(blob.Value)), + } + field2Logidx[fieldID] = logidx + } + + field2Stats := make(map[UniqueID]*datapb.Binlog) + // write stats binlog + for _, blob := range statsBinlogs { + fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64) + if err != nil { + log.Error("Flush failed ... cannot parse string to fieldID ..", zap.Error(err)) + return err + } + + logidx := field2Logidx[fieldID] + + // no error raise if alloc=false + k := JoinIDPath(req.GetImportTask().GetCollectionId(), req.GetImportTask().GetPartitionId(), segmentID, fieldID, logidx) + + key := path.Join(Params.DataNodeCfg.StatsBinlogRootPath, k) + kvs[key] = blob.Value + field2Stats[fieldID] = &datapb.Binlog{ + EntriesNum: 0, + TimestampFrom: 0, //TODO + TimestampTo: 0, //TODO, + LogPath: key, + LogSize: int64(len(blob.Value)), + } + } + + err = node.chunkManager.MultiWrite(kvs) + if err != nil { + return err + } + var ( + fieldInsert []*datapb.FieldBinlog + fieldStats []*datapb.FieldBinlog + ) + + for k, v := range field2Insert { + fieldInsert = append(fieldInsert, &datapb.FieldBinlog{FieldID: k, Binlogs: []*datapb.Binlog{v}}) + } + for k, v := range field2Stats { + fieldStats = append(fieldStats, &datapb.FieldBinlog{FieldID: k, Binlogs: []*datapb.Binlog{v}}) + } + + req := &datapb.SaveBinlogPathsRequest{ + Base: &commonpb.MsgBase{ + MsgType: 0, //TODO msg type + MsgID: 0, //TODO msg id + Timestamp: 0, //TODO time stamp + SourceID: Params.DataNodeCfg.NodeID, + }, + SegmentID: segmentID, + CollectionID: req.ImportTask.GetCollectionId(), + Field2BinlogPaths: fieldInsert, + Field2StatslogPaths: fieldStats, + Importing: true, + } + + err = retry.Do(context.Background(), func() error { + rsp, err := node.dataCoord.SaveBinlogPaths(context.Background(), req) + // should be network issue, return error and retry + if err != nil { + return fmt.Errorf(err.Error()) + } + + // TODO should retry only when datacoord status is unhealthy + if rsp.ErrorCode != commonpb.ErrorCode_Success { + return fmt.Errorf("data service save bin log path failed, reason = %s", rsp.Reason) + } + return nil + }) + if err != nil { + log.Warn("failed to SaveBinlogPaths", zap.Error(err)) + return err + } + + return nil + } + +} diff --git a/internal/datanode/data_node_test.go b/internal/datanode/data_node_test.go index 574b24ce20..1df36ef9c0 100644 --- a/internal/datanode/data_node_test.go +++ b/internal/datanode/data_node_test.go @@ -32,6 +32,7 @@ import ( etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/mq/msgstream" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" @@ -80,6 +81,9 @@ func TestDataNode(t *testing.T) { err = node.Start() assert.Nil(t, err) + node.chunkManager = storage.NewLocalChunkManager(storage.RootPath("/tmp/lib/milvus")) + Params.DataNodeCfg.NodeID = 1 + t.Run("Test WatchDmChannels ", func(t *testing.T) { emptyNode := &DataNode{} @@ -316,6 +320,34 @@ func TestDataNode(t 
*testing.T) { }) t.Run("Test Import", func(t *testing.T) { + content := []byte(`{ + "rows":[ + {"bool_field": true, "int8_field": 10, "int16_field": 101, "int32_field": 1001, "int64_field": 10001, "float32_field": 3.14, "float64_field": 1.56, "varChar_field": "hello world", "binary_vector_field": [254, 0, 254, 0], "float_vector_field": [1.1, 1.2]}, + {"bool_field": false, "int8_field": 11, "int16_field": 102, "int32_field": 1002, "int64_field": 10002, "float32_field": 3.15, "float64_field": 2.56, "varChar_field": "hello world", "binary_vector_field": [253, 0, 253, 0], "float_vector_field": [2.1, 2.2]}, + {"bool_field": true, "int8_field": 12, "int16_field": 103, "int32_field": 1003, "int64_field": 10003, "float32_field": 3.16, "float64_field": 3.56, "varChar_field": "hello world", "binary_vector_field": [252, 0, 252, 0], "float_vector_field": [3.1, 3.2]}, + {"bool_field": false, "int8_field": 13, "int16_field": 104, "int32_field": 1004, "int64_field": 10004, "float32_field": 3.17, "float64_field": 4.56, "varChar_field": "hello world", "binary_vector_field": [251, 0, 251, 0], "float_vector_field": [4.1, 4.2]}, + {"bool_field": true, "int8_field": 14, "int16_field": 105, "int32_field": 1005, "int64_field": 10005, "float32_field": 3.18, "float64_field": 5.56, "varChar_field": "hello world", "binary_vector_field": [250, 0, 250, 0], "float_vector_field": [5.1, 5.2]} + ] + }`) + + filePath := "import/rows_1.json" + err = node.chunkManager.Write(filePath, content) + assert.NoError(t, err) + req := &datapb.ImportTaskRequest{ + ImportTask: &datapb.ImportTask{ + CollectionId: 100, + PartitionId: 100, + Files: []string{filePath}, + RowBased: true, + }, + } + stat, err := node.Import(node.ctx, req) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stat.ErrorCode) + }) + + t.Run("Test Import error", func(t *testing.T) { + node.rootCoord = &RootCoordFactory{collectionID: -1} req := &datapb.ImportTaskRequest{ ImportTask: &datapb.ImportTask{ CollectionId: 100, @@ -324,7 +356,7 @@ func TestDataNode(t *testing.T) { } stat, err := node.Import(node.ctx, req) assert.NoError(t, err) - assert.Equal(t, commonpb.ErrorCode_Success, stat.ErrorCode) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stat.ErrorCode) }) t.Run("Test BackGroundGC", func(t *testing.T) { @@ -585,7 +617,6 @@ func TestWatchChannel(t *testing.T) { exist := node.flowgraphManager.exist("test3") assert.False(t, exist) }) - } func TestDataNode_GetComponentStates(t *testing.T) { diff --git a/internal/datanode/mock_test.go b/internal/datanode/mock_test.go index 6a6f3efb31..70fb8d1eea 100644 --- a/internal/datanode/mock_test.go +++ b/internal/datanode/mock_test.go @@ -173,6 +173,19 @@ type DataCoordFactory struct { DropVirtualChannelNotSuccess bool } +func (ds *DataCoordFactory) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { + return &datapb.AssignSegmentIDResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + SegIDAssignments: []*datapb.SegmentIDAssignment{ + { + SegID: 666, + }, + }, + }, nil +} + func (ds *DataCoordFactory) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult) (*commonpb.Status, error) { if ds.CompleteCompactionError { return nil, errors.New("Error") @@ -843,6 +856,12 @@ func (m *RootCoordFactory) AllocID(ctx context.Context, in *rootcoordpb.AllocIDR ErrorCode: commonpb.ErrorCode_UnexpectedError, }} + if in.Count == 12 { + resp.Status.ErrorCode = commonpb.ErrorCode_Success + resp.ID = 1 + 
resp.Count = 12 + } + if m.ID == 0 { resp.Status.Reason = "Zero ID" return resp, nil diff --git a/internal/storage/local_chunk_manager.go b/internal/storage/local_chunk_manager.go index 70dfee531e..2176b81955 100644 --- a/internal/storage/local_chunk_manager.go +++ b/internal/storage/local_chunk_manager.go @@ -48,8 +48,8 @@ func NewLocalChunkManager(opts ...Option) *LocalChunkManager { } } -// GetPath returns the path of local data if exists. -func (lcm *LocalChunkManager) GetPath(filePath string) (string, error) { +// Path returns the path of local data if exists. +func (lcm *LocalChunkManager) Path(filePath string) (string, error) { if !lcm.Exist(filePath) { return "", errors.New("local file cannot be found with filePath:" + filePath) } @@ -57,6 +57,14 @@ func (lcm *LocalChunkManager) GetPath(filePath string) (string, error) { return absPath, nil } +func (lcm *LocalChunkManager) Reader(filePath string) (FileReader, error) { + if !lcm.Exist(filePath) { + return nil, errors.New("local file cannot be found with filePath:" + filePath) + } + absPath := path.Join(lcm.localPath, filePath) + return os.Open(absPath) +} + // Write writes the data to local storage. func (lcm *LocalChunkManager) Write(filePath string, content []byte) error { absPath := path.Join(lcm.localPath, filePath) @@ -181,7 +189,7 @@ func (lcm *LocalChunkManager) Mmap(filePath string) (*mmap.ReaderAt, error) { return mmap.Open(path.Clean(absPath)) } -func (lcm *LocalChunkManager) GetSize(filePath string) (int64, error) { +func (lcm *LocalChunkManager) Size(filePath string) (int64, error) { absPath := path.Join(lcm.localPath, filePath) fi, err := os.Stat(absPath) if err != nil { diff --git a/internal/storage/local_chunk_manager_test.go b/internal/storage/local_chunk_manager_test.go index e97cddb005..087e084fc3 100644 --- a/internal/storage/local_chunk_manager_test.go +++ b/internal/storage/local_chunk_manager_test.go @@ -325,7 +325,7 @@ func TestLocalCM(t *testing.T) { assert.Error(t, err) }) - t.Run("test GetSize", func(t *testing.T) { + t.Run("test Size", func(t *testing.T) { testGetSizeRoot := "get_size" testCM := NewLocalChunkManager(RootPath(localPath)) @@ -337,18 +337,18 @@ func TestLocalCM(t *testing.T) { err := testCM.Write(key, value) assert.NoError(t, err) - size, err := testCM.GetSize(key) + size, err := testCM.Size(key) assert.NoError(t, err) assert.Equal(t, size, int64(len(value))) key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2") - size, err = testCM.GetSize(key2) + size, err = testCM.Size(key2) assert.Error(t, err) assert.Equal(t, int64(0), size) }) - t.Run("test GetPath", func(t *testing.T) { + t.Run("test Path", func(t *testing.T) { testGetSizeRoot := "get_path" testCM := NewLocalChunkManager(RootPath(localPath)) @@ -360,13 +360,13 @@ func TestLocalCM(t *testing.T) { err := testCM.Write(key, value) assert.NoError(t, err) - p, err := testCM.GetPath(key) + p, err := testCM.Path(key) assert.NoError(t, err) assert.Equal(t, p, path.Join(localPath, key)) key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2") - p, err = testCM.GetPath(key2) + p, err = testCM.Path(key2) assert.Error(t, err) assert.Equal(t, p, "") }) diff --git a/internal/storage/minio_chunk_manager.go b/internal/storage/minio_chunk_manager.go index 3c1e494459..b563cd2396 100644 --- a/internal/storage/minio_chunk_manager.go +++ b/internal/storage/minio_chunk_manager.go @@ -94,15 +94,23 @@ func newMinioChunkManagerWithConfig(ctx context.Context, c *config) (*MinioChunk return mcm, nil } -// GetPath returns the path of minio data if 
exists. -func (mcm *MinioChunkManager) GetPath(filePath string) (string, error) { +// Path returns the path of minio data if exists. +func (mcm *MinioChunkManager) Path(filePath string) (string, error) { if !mcm.Exist(filePath) { return "", errors.New("minio file manage cannot be found with filePath:" + filePath) } return filePath, nil } -func (mcm *MinioChunkManager) GetSize(filePath string) (int64, error) { +// Reader returns the path of minio data if exists. +func (mcm *MinioChunkManager) Reader(filePath string) (FileReader, error) { + if !mcm.Exist(filePath) { + return nil, errors.New("minio file manage cannot be found with filePath:" + filePath) + } + return mcm.Client.GetObject(mcm.ctx, mcm.bucketName, filePath, minio.GetObjectOptions{}) +} + +func (mcm *MinioChunkManager) Size(filePath string) (int64, error) { objectInfo, err := mcm.Client.StatObject(mcm.ctx, mcm.bucketName, filePath, minio.StatObjectOptions{}) if err != nil { return 0, err diff --git a/internal/storage/minio_chunk_manager_test.go b/internal/storage/minio_chunk_manager_test.go index cf9efca1b8..546bb90050 100644 --- a/internal/storage/minio_chunk_manager_test.go +++ b/internal/storage/minio_chunk_manager_test.go @@ -354,7 +354,7 @@ func TestMinIOCM(t *testing.T) { assert.Error(t, err) }) - t.Run("test GetSize", func(t *testing.T) { + t.Run("test Size", func(t *testing.T) { testGetSizeRoot := path.Join(testMinIOKVRoot, "get_size") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -369,18 +369,18 @@ func TestMinIOCM(t *testing.T) { err = testCM.Write(key, value) assert.NoError(t, err) - size, err := testCM.GetSize(key) + size, err := testCM.Size(key) assert.NoError(t, err) assert.Equal(t, size, int64(len(value))) key2 := path.Join(testGetSizeRoot, "TestMemoryKV_GetSize_key2") - size, err = testCM.GetSize(key2) + size, err = testCM.Size(key2) assert.Error(t, err) assert.Equal(t, int64(0), size) }) - t.Run("test GetPath", func(t *testing.T) { + t.Run("test Path", func(t *testing.T) { testGetPathRoot := path.Join(testMinIOKVRoot, "get_path") ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -395,13 +395,13 @@ func TestMinIOCM(t *testing.T) { err = testCM.Write(key, value) assert.NoError(t, err) - p, err := testCM.GetPath(key) + p, err := testCM.Path(key) assert.NoError(t, err) assert.Equal(t, p, key) key2 := path.Join(testGetPathRoot, "TestMemoryKV_GetSize_key2") - p, err = testCM.GetPath(key2) + p, err = testCM.Path(key2) assert.Error(t, err) assert.Equal(t, p, "") }) diff --git a/internal/storage/types.go b/internal/storage/types.go index 57e4437609..9ce86b1c7e 100644 --- a/internal/storage/types.go +++ b/internal/storage/types.go @@ -12,16 +12,23 @@ package storage import ( + "io" + "golang.org/x/exp/mmap" ) +type FileReader interface { + io.Reader + io.Closer +} + // ChunkManager is to manager chunks. // Include Read, Write, Remove chunks. type ChunkManager interface { - // GetPath returns path of @filePath. - GetPath(filePath string) (string, error) - // GetSize returns path of @filePath. - GetSize(filePath string) (int64, error) + // Path returns path of @filePath. + Path(filePath string) (string, error) + // Size returns path of @filePath. + Size(filePath string) (int64, error) // Write writes @content to @filePath. Write(filePath string, content []byte) error // MultiWrite writes multi @content to @filePath. @@ -30,6 +37,8 @@ type ChunkManager interface { Exist(filePath string) bool // Read reads @filePath and returns content. 
Read(filePath string) ([]byte, error) + // Reader return a reader for @filePath + Reader(filePath string) (FileReader, error) // MultiRead reads @filePath and returns content. MultiRead(filePaths []string) ([][]byte, error) ListWithPrefix(prefix string) ([]string, error) diff --git a/internal/storage/vector_chunk_manager.go b/internal/storage/vector_chunk_manager.go index 2c59d140f8..7628322de2 100644 --- a/internal/storage/vector_chunk_manager.go +++ b/internal/storage/vector_chunk_manager.go @@ -21,12 +21,13 @@ import ( "io" "sync" + "go.uber.org/zap" + "golang.org/x/exp/mmap" + "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/util/cache" - "go.uber.org/zap" - "golang.org/x/exp/mmap" ) var ( @@ -116,12 +117,12 @@ func (vcm *VectorChunkManager) deserializeVectorFile(filePath string, content [] // GetPath returns the path of vector data. If cached, return local path. // If not cached return remote path. -func (vcm *VectorChunkManager) GetPath(filePath string) (string, error) { - return vcm.vectorStorage.GetPath(filePath) +func (vcm *VectorChunkManager) Path(filePath string) (string, error) { + return vcm.vectorStorage.Path(filePath) } -func (vcm *VectorChunkManager) GetSize(filePath string) (int64, error) { - return vcm.vectorStorage.GetSize(filePath) +func (vcm *VectorChunkManager) Size(filePath string) (int64, error) { + return vcm.vectorStorage.Size(filePath) } // Write writes the vector data to local cache if cache enabled. @@ -156,7 +157,7 @@ func (vcm *VectorChunkManager) readWithCache(filePath string) ([]byte, error) { if err != nil { return nil, err } - size, err := vcm.cacheStorage.GetSize(filePath) + size, err := vcm.cacheStorage.Size(filePath) if err != nil { return nil, err } @@ -239,6 +240,10 @@ func (vcm *VectorChunkManager) Mmap(filePath string) (*mmap.ReaderAt, error) { return nil, errors.New("the file mmap has not been cached") } +func (vcm *VectorChunkManager) Reader(filePath string) (FileReader, error) { + return nil, errors.New("this method has not been implemented") +} + // ReadAt reads specific position data of vector. If cached, it reads from local. 
func (vcm *VectorChunkManager) ReadAt(filePath string, off int64, length int64) ([]byte, error) { if vcm.cacheEnable { diff --git a/internal/storage/vector_chunk_manager_test.go b/internal/storage/vector_chunk_manager_test.go index df3e00ce8d..09d768119f 100644 --- a/internal/storage/vector_chunk_manager_test.go +++ b/internal/storage/vector_chunk_manager_test.go @@ -22,11 +22,12 @@ import ( "os" "testing" + "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/internal/proto/etcdpb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/util/paramtable" "github.com/milvus-io/milvus/internal/util/typeutil" - "github.com/stretchr/testify/assert" ) func initMeta() *etcdpb.CollectionMeta { @@ -179,13 +180,13 @@ func TestVectorChunkManager_GetPath(t *testing.T) { key := "1" err = vcm.Write(key, []byte{1}) assert.Nil(t, err) - pathGet, err := vcm.GetPath(key) + pathGet, err := vcm.Path(key) assert.Nil(t, err) assert.Equal(t, pathGet, key) err = vcm.cacheStorage.Write(key, []byte{1}) assert.Nil(t, err) - pathGet, err = vcm.GetPath(key) + pathGet, err = vcm.Path(key) assert.Nil(t, err) assert.Equal(t, pathGet, key) @@ -206,13 +207,13 @@ func TestVectorChunkManager_GetSize(t *testing.T) { key := "1" err = vcm.Write(key, []byte{1}) assert.Nil(t, err) - sizeGet, err := vcm.GetSize(key) + sizeGet, err := vcm.Size(key) assert.Nil(t, err) assert.EqualValues(t, sizeGet, 1) err = vcm.cacheStorage.Write(key, []byte{1}) assert.Nil(t, err) - sizeGet, err = vcm.GetSize(key) + sizeGet, err = vcm.Size(key) assert.Nil(t, err) assert.EqualValues(t, sizeGet, 1) diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index 3e15824c59..c87245e053 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -4,19 +4,19 @@ import ( "bufio" "context" "errors" - "os" "path" "strconv" "strings" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/typeutil" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" ) const ( @@ -29,14 +29,15 @@ type ImportWrapper struct { cancel context.CancelFunc // for canceling parse process collectionSchema *schemapb.CollectionSchema // collection schema shardNum int32 // sharding number of the collection - segmentSize int32 // maximum size of a segment in MB + segmentSize int64 // maximum size of a segment in MB rowIDAllocator *allocator.IDAllocator // autoid allocator + chunkManager storage.ChunkManager - callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush a segment + callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush a segment } -func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int32, - idAlloc *allocator.IDAllocator, flushFunc func(fields map[string]storage.FieldData) error) *ImportWrapper { +func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.CollectionSchema, shardNum int32, segmentSize int64, + idAlloc *allocator.IDAllocator, cm storage.ChunkManager, flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *ImportWrapper { if collectionSchema == nil { log.Error("import error: collection 
schema is nil") return nil @@ -67,6 +68,7 @@ func NewImportWrapper(ctx context.Context, collectionSchema *schemapb.Collection segmentSize: segmentSize, rowIDAllocator: idAlloc, callFlushFunc: flushFunc, + chunkManager: cm, } return wrapper @@ -78,10 +80,10 @@ func (p *ImportWrapper) Cancel() error { return nil } -func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[string]storage.FieldData, msg string, files []string) { +func (p *ImportWrapper) printFieldsDataInfo(fieldsData map[storage.FieldID]storage.FieldData, msg string, files []string) { stats := make([]zapcore.Field, 0) for k, v := range fieldsData { - stats = append(stats, zap.Int(k, v.RowNum())) + stats = append(stats, zap.Int(strconv.FormatInt(k, 10), v.RowNum())) } if len(files) > 0 { @@ -112,7 +114,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b if fileType == JSONFileExt { err := func() error { - file, err := os.Open(filePath) + file, err := p.chunkManager.Reader(filePath) if err != nil { return err } @@ -122,7 +124,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b parser := NewJSONParser(p.ctx, p.collectionSchema) var consumer *JSONRowConsumer if !onlyValidate { - flushFunc := func(fields map[string]storage.FieldData) error { + flushFunc := func(fields map[storage.FieldID]storage.FieldData) error { p.printFieldsDataInfo(fields, "import wrapper: prepare to flush segment", filePaths) return p.callFlushFunc(fields) } @@ -153,14 +155,14 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b rowCount := 0 // function to combine column data into fieldsData - combineFunc := func(fields map[string]storage.FieldData) error { + combineFunc := func(fields map[storage.FieldID]storage.FieldData) error { if len(fields) == 0 { return nil } p.printFieldsDataInfo(fields, "imprort wrapper: combine field data", nil) - fieldNames := make([]string, 0) + fieldNames := make([]storage.FieldID, 0) for k, v := range fields { // ignore 0 row field if v.RowNum() == 0 { @@ -170,12 +172,12 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b // each column should be only combined once data, ok := fieldsData[k] if ok && data.RowNum() > 0 { - return errors.New("the field " + k + " is duplicated") + return errors.New("the field " + strconv.FormatInt(k, 10) + " is duplicated") } // check the row count. 
only count non-zero row fields if rowCount > 0 && rowCount != v.RowNum() { - return errors.New("the field " + k + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount)) + return errors.New("the field " + strconv.FormatInt(k, 10) + " row count " + strconv.Itoa(v.RowNum()) + " doesn't equal " + strconv.Itoa(rowCount)) } rowCount = v.RowNum() @@ -195,7 +197,7 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b if fileType == JSONFileExt { err := func() error { - file, err := os.Open(filePath) + file, err := p.chunkManager.Reader(filePath) if err != nil { log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath)) return err @@ -224,17 +226,23 @@ func (p *ImportWrapper) Import(filePaths []string, rowBased bool, onlyValidate b return err } } else if fileType == NumpyFileExt { - file, err := os.Open(filePath) + file, err := p.chunkManager.Reader(filePath) if err != nil { log.Error("imprort error: "+err.Error(), zap.String("filePath", filePath)) return err } defer file.Close() + var id storage.FieldID + for _, field := range p.collectionSchema.Fields { + if field.GetName() == fileName { + id = field.GetFieldID() + } + } // the numpy parser return a storage.FieldData, here construct a map[string]storage.FieldData to combine flushFunc := func(field storage.FieldData) error { - fields := make(map[string]storage.FieldData) - fields[fileName] = field + fields := make(map[storage.FieldID]storage.FieldData) + fields[id] = field combineFunc(fields) return nil } @@ -325,7 +333,7 @@ func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storag arr.NumRows[0]++ return nil } - case schemapb.DataType_String: + case schemapb.DataType_String, schemapb.DataType_VarChar: return func(src storage.FieldData, n int, target storage.FieldData) error { arr := target.(*storage.StringFieldData) arr.Data = append(arr.Data, src.GetRow(n).(string)) @@ -336,7 +344,7 @@ func (p *ImportWrapper) appendFunc(schema *schemapb.FieldSchema) func(src storag } } -func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, files []string) error { +func (p *ImportWrapper) splitFieldsData(fieldsData map[storage.FieldID]storage.FieldData, files []string) error { if len(fieldsData) == 0 { return errors.New("imprort error: fields data is empty") } @@ -347,7 +355,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, if schema.GetIsPrimaryKey() { primaryKey = schema } else { - _, ok := fieldsData[schema.GetName()] + _, ok := fieldsData[schema.GetFieldID()] if !ok { return errors.New("imprort error: field " + schema.GetName() + " not provided") } @@ -363,7 +371,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, break } - primaryData, ok := fieldsData[primaryKey.GetName()] + primaryData, ok := fieldsData[primaryKey.GetFieldID()] if !ok { // generate auto id for primary key if primaryKey.GetAutoID() { @@ -383,7 +391,7 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, } // prepare segemnts - segmentsData := make([]map[string]storage.FieldData, 0, p.shardNum) + segmentsData := make([]map[storage.FieldID]storage.FieldData, 0, p.shardNum) for i := 0; i < int(p.shardNum); i++ { segmentData := initSegmentData(p.collectionSchema) if segmentData == nil { @@ -412,8 +420,8 @@ func (p *ImportWrapper) splitFieldsData(fieldsData map[string]storage.FieldData, for k := 0; k < len(p.collectionSchema.Fields); k++ { schema := 
p.collectionSchema.Fields[k] - srcData := fieldsData[schema.GetName()] - targetData := segmentsData[shard][schema.GetName()] + srcData := fieldsData[schema.GetFieldID()] + targetData := segmentsData[shard][schema.GetFieldID()] appendFunc := appendFunctions[schema.GetName()] err := appendFunc(srcData, i, targetData) if err != nil { diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index b9d4d40359..c8bf7f37fd 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -1,18 +1,23 @@ package importutil import ( + "bufio" + "bytes" "context" "encoding/json" - "os" "strconv" "testing" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/timerecord" - "github.com/stretchr/testify/assert" ) const ( @@ -20,8 +25,11 @@ const ( ) func Test_NewImportWrapper(t *testing.T) { + f := dependency.NewDefaultFactory(true) ctx := context.Background() - wrapper := NewImportWrapper(ctx, nil, 2, 1, nil, nil) + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + wrapper := NewImportWrapper(ctx, nil, 2, 1, nil, cm, nil) assert.Nil(t, wrapper) schema := &schemapb.CollectionSchema{ @@ -39,28 +47,18 @@ func Test_NewImportWrapper(t *testing.T) { Description: "int64", DataType: schemapb.DataType_Int64, }) - wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, nil) + wrapper = NewImportWrapper(ctx, schema, 2, 1, nil, cm, nil) assert.NotNil(t, wrapper) - err := wrapper.Cancel() + err = wrapper.Cancel() assert.Nil(t, err) } -func saveFile(t *testing.T, filePath string, content []byte) *os.File { - fp, err := os.Create(filePath) - assert.Nil(t, err) - - _, err = fp.Write(content) - assert.Nil(t, err) - - return fp -} - func Test_ImportRowBased(t *testing.T) { + f := dependency.NewDefaultFactory(true) ctx := context.Background() - err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.Nil(t, err) - defer os.RemoveAll(TempFilesPath) + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) idAllocator := newIDAllocator(ctx, t) @@ -75,11 +73,12 @@ func Test_ImportRowBased(t *testing.T) { }`) filePath := TempFilesPath + "rows_1.json" - fp1 := saveFile(t, filePath, content) - defer fp1.Close() + err = cm.Write(filePath, content) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") rowCount := 0 - flushFunc := func(fields map[string]storage.FieldData) error { + flushFunc := func(fields map[storage.FieldID]storage.FieldData) error { count := 0 for _, data := range fields { assert.Less(t, 0, data.RowNum()) @@ -94,7 +93,7 @@ func Test_ImportRowBased(t *testing.T) { } // success case - wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc) files := make([]string, 0) files = append(files, filePath) err = wrapper.Import(files, true, false) @@ -109,10 +108,10 @@ func Test_ImportRowBased(t *testing.T) { }`) filePath = TempFilesPath + "rows_2.json" - fp2 := saveFile(t, filePath, content) - defer fp2.Close() + err = cm.Write(filePath, content) + assert.NoError(t, err) - wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, 
idAllocator, flushFunc) + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc) files = make([]string, 0) files = append(files, filePath) err = wrapper.Import(files, true, false) @@ -127,10 +126,11 @@ func Test_ImportRowBased(t *testing.T) { } func Test_ImportColumnBased_json(t *testing.T) { + f := dependency.NewDefaultFactory(true) ctx := context.Background() - err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.Nil(t, err) - defer os.RemoveAll(TempFilesPath) + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") idAllocator := newIDAllocator(ctx, t) @@ -160,11 +160,11 @@ func Test_ImportColumnBased_json(t *testing.T) { }`) filePath := TempFilesPath + "columns_1.json" - fp1 := saveFile(t, filePath, content) - defer fp1.Close() + err = cm.Write(filePath, content) + assert.NoError(t, err) rowCount := 0 - flushFunc := func(fields map[string]storage.FieldData) error { + flushFunc := func(fields map[storage.FieldID]storage.FieldData) error { count := 0 for _, data := range fields { assert.Less(t, 0, data.RowNum()) @@ -179,7 +179,7 @@ func Test_ImportColumnBased_json(t *testing.T) { } // success case - wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc) files := make([]string, 0) files = append(files, filePath) err = wrapper.Import(files, false, false) @@ -192,10 +192,10 @@ func Test_ImportColumnBased_json(t *testing.T) { }`) filePath = TempFilesPath + "rows_2.json" - fp2 := saveFile(t, filePath, content) - defer fp2.Close() + err = cm.Write(filePath, content) + assert.NoError(t, err) - wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc) files = make([]string, 0) files = append(files, filePath) err = wrapper.Import(files, false, false) @@ -209,10 +209,11 @@ func Test_ImportColumnBased_json(t *testing.T) { } func Test_ImportColumnBased_numpy(t *testing.T) { + f := dependency.NewDefaultFactory(true) ctx := context.Background() - err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.Nil(t, err) - defer os.RemoveAll(TempFilesPath) + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") idAllocator := newIDAllocator(ctx, t) @@ -230,24 +231,30 @@ func Test_ImportColumnBased_numpy(t *testing.T) { files := make([]string, 0) filePath := TempFilesPath + "scalar_fields.json" - fp1 := saveFile(t, filePath, content) - fp1.Close() + err = cm.Write(filePath, content) + assert.NoError(t, err) files = append(files, filePath) filePath = TempFilesPath + "field_binary_vector.npy" bin := [][2]uint8{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}} - err = CreateNumpyFile(filePath, bin) + content, err = CreateNumpyData(bin) assert.Nil(t, err) + log.Debug("content", zap.Any("c", content)) + err = cm.Write(filePath, content) + assert.NoError(t, err) files = append(files, filePath) filePath = TempFilesPath + "field_float_vector.npy" flo := [][4]float32{{1, 2, 3, 4}, {3, 4, 5, 6}, {5, 6, 7, 8}, {7, 8, 9, 10}, {9, 10, 11, 12}} - err = CreateNumpyFile(filePath, flo) + content, err = CreateNumpyData(flo) assert.Nil(t, err) + log.Debug("content", zap.Any("c", content)) + err = cm.Write(filePath, content) + assert.NoError(t, err) files = append(files, filePath) rowCount := 0 - flushFunc := func(fields map[string]storage.FieldData) error { + flushFunc := func(fields 
map[storage.FieldID]storage.FieldData) error { count := 0 for _, data := range fields { assert.Less(t, 0, data.RowNum()) @@ -262,7 +269,7 @@ func Test_ImportColumnBased_numpy(t *testing.T) { } // success case - wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + wrapper := NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc) err = wrapper.Import(files, false, false) assert.Nil(t, err) @@ -274,10 +281,10 @@ func Test_ImportColumnBased_numpy(t *testing.T) { }`) filePath = TempFilesPath + "rows_2.json" - fp2 := saveFile(t, filePath, content) - defer fp2.Close() + err = cm.Write(filePath, content) + assert.NoError(t, err) - wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, flushFunc) + wrapper = NewImportWrapper(ctx, sampleSchema(), 2, 1, idAllocator, cm, flushFunc) files = make([]string, 0) files = append(files, filePath) err = wrapper.Import(files, false, false) @@ -321,10 +328,11 @@ func perfSchema(dim int) *schemapb.CollectionSchema { } func Test_ImportRowBased_perf(t *testing.T) { + f := dependency.NewDefaultFactory(true) ctx := context.Background() - err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.Nil(t, err) - defer os.RemoveAll(TempFilesPath) + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") idAllocator := newIDAllocator(ctx, t) @@ -365,19 +373,22 @@ func Test_ImportRowBased_perf(t *testing.T) { // generate a json file filePath := TempFilesPath + "row_perf.json" func() { - fp, err := os.Create(filePath) - assert.Nil(t, err) - defer fp.Close() + var b bytes.Buffer + bw := bufio.NewWriter(&b) - encoder := json.NewEncoder(fp) + encoder := json.NewEncoder(bw) err = encoder.Encode(entities) assert.Nil(t, err) + err = bw.Flush() + assert.NoError(t, err) + err = cm.Write(filePath, b.Bytes()) + assert.NoError(t, err) }() tr.Record("generate large json file " + filePath) // parse the json file parseCount := 0 - flushFunc := func(fields map[string]storage.FieldData) error { + flushFunc := func(fields map[storage.FieldID]storage.FieldData) error { count := 0 for _, data := range fields { assert.Less(t, 0, data.RowNum()) @@ -393,7 +404,7 @@ func Test_ImportRowBased_perf(t *testing.T) { schema := perfSchema(dim) - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int32(segmentSize), idAllocator, flushFunc) + wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, flushFunc) files := make([]string, 0) files = append(files, filePath) err = wrapper.Import(files, true, false) @@ -404,10 +415,11 @@ func Test_ImportRowBased_perf(t *testing.T) { } func Test_ImportColumnBased_perf(t *testing.T) { + f := dependency.NewDefaultFactory(true) ctx := context.Background() - err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.Nil(t, err) - defer os.RemoveAll(TempFilesPath) + cm, err := f.NewVectorStorageChunkManager(ctx) + assert.NoError(t, err) + defer cm.RemoveWithPrefix("") idAllocator := newIDAllocator(ctx, t) @@ -449,15 +461,17 @@ func Test_ImportColumnBased_perf(t *testing.T) { // generate json files saveFileFunc := func(filePath string, data interface{}) error { - fp, err := os.Create(filePath) - if err != nil { - return err - } - defer fp.Close() + var b bytes.Buffer + bw := bufio.NewWriter(&b) - encoder := json.NewEncoder(fp) + encoder := json.NewEncoder(bw) err = encoder.Encode(data) - return err + assert.Nil(t, err) + err = bw.Flush() + assert.NoError(t, err) + err = cm.Write(filePath, b.Bytes()) + assert.NoError(t, err) + 
return nil } filePath1 := TempFilesPath + "ids.json" @@ -472,7 +486,7 @@ func Test_ImportColumnBased_perf(t *testing.T) { // parse the json file parseCount := 0 - flushFunc := func(fields map[string]storage.FieldData) error { + flushFunc := func(fields map[storage.FieldID]storage.FieldData) error { count := 0 for _, data := range fields { assert.Less(t, 0, data.RowNum()) @@ -488,7 +502,7 @@ func Test_ImportColumnBased_perf(t *testing.T) { schema := perfSchema(dim) - wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int32(segmentSize), idAllocator, flushFunc) + wrapper := NewImportWrapper(ctx, schema, int32(shardNum), int64(segmentSize), idAllocator, cm, flushFunc) files := make([]string, 0) files = append(files, filePath1) files = append(files, filePath2) diff --git a/internal/util/importutil/json_handler.go b/internal/util/importutil/json_handler.go index 178f5a6cf5..83e1a2dc35 100644 --- a/internal/util/importutil/json_handler.go +++ b/internal/util/importutil/json_handler.go @@ -5,22 +5,23 @@ import ( "fmt" "strconv" + "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/typeutil" - "go.uber.org/zap" ) // interface to process rows data type JSONRowHandler interface { - Handle(rows []map[string]interface{}) error + Handle(rows []map[storage.FieldID]interface{}) error } // interface to process column data type JSONColumnHandler interface { - Handle(columns map[string][]interface{}) error + Handle(columns map[storage.FieldID][]interface{}) error } // method to get dimension of vecotor field @@ -49,7 +50,7 @@ type Validator struct { } // method to construct valiator functions -func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[string]*Validator) error { +func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[storage.FieldID]*Validator) error { if collectionSchema == nil { return errors.New("collection schema is nil") } @@ -70,13 +71,13 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ for i := 0; i < len(collectionSchema.Fields); i++ { schema := collectionSchema.Fields[i] - validators[schema.GetName()] = &Validator{} - validators[schema.GetName()].primaryKey = schema.GetIsPrimaryKey() - validators[schema.GetName()].autoID = schema.GetAutoID() + validators[schema.GetFieldID()] = &Validator{} + validators[schema.GetFieldID()].primaryKey = schema.GetIsPrimaryKey() + validators[schema.GetFieldID()].autoID = schema.GetAutoID() switch schema.DataType { case schemapb.DataType_Bool: - validators[schema.GetName()].validateFunc = func(obj interface{}) error { + validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { switch obj.(type) { case bool: return nil @@ -87,55 +88,55 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ } } - validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { value := obj.(bool) field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value) field.(*storage.BoolFieldData).NumRows[0]++ return nil } case schemapb.DataType_Float: - validators[schema.GetName()].validateFunc = numericValidator - validators[schema.GetName()].convertFunc = func(obj interface{}, field 
storage.FieldData) error { + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { value := float32(obj.(float64)) field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, value) field.(*storage.FloatFieldData).NumRows[0]++ return nil } case schemapb.DataType_Double: - validators[schema.GetName()].validateFunc = numericValidator - validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { value := obj.(float64) field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value) field.(*storage.DoubleFieldData).NumRows[0]++ return nil } case schemapb.DataType_Int8: - validators[schema.GetName()].validateFunc = numericValidator - validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { value := int8(obj.(float64)) field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, value) field.(*storage.Int8FieldData).NumRows[0]++ return nil } case schemapb.DataType_Int16: - validators[schema.GetName()].validateFunc = numericValidator - validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { value := int16(obj.(float64)) field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, value) field.(*storage.Int16FieldData).NumRows[0]++ return nil } case schemapb.DataType_Int32: - validators[schema.GetName()].validateFunc = numericValidator - validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { value := int32(obj.(float64)) field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, value) field.(*storage.Int32FieldData).NumRows[0]++ return nil } case schemapb.DataType_Int64: - validators[schema.GetName()].validateFunc = numericValidator - validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error { + validators[schema.GetFieldID()].validateFunc = numericValidator + validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error { value := int64(obj.(float64)) field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value) field.(*storage.Int64FieldData).NumRows[0]++ @@ -146,9 +147,9 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[ if err != nil { return err } - validators[schema.GetName()].dimension = dim + validators[schema.GetFieldID()].dimension = dim - validators[schema.GetName()].validateFunc = func(obj interface{}) error { + validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error { switch vt := obj.(type) { case []interface{}: if len(vt)*8 != dim { @@ -175,7 +176,7 @@ func 
initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
 			}
 		}
 
-		validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
+		validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
 			arr := obj.([]interface{})
 			for i := 0; i < len(arr); i++ {
 				value := byte(arr[i].(float64))
@@ -190,9 +191,9 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
 			if err != nil {
 				return err
 			}
-			validators[schema.GetName()].dimension = dim
+			validators[schema.GetFieldID()].dimension = dim
 
-		validators[schema.GetName()].validateFunc = func(obj interface{}) error {
+		validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
 			switch vt := obj.(type) {
 			case []interface{}:
 				if len(vt) != dim {
@@ -213,7 +214,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
 			}
 		}
 
-		validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
+		validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
 			arr := obj.([]interface{})
 			for i := 0; i < len(arr); i++ {
 				value := float32(arr[i].(float64))
@@ -222,8 +223,8 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
 			field.(*storage.FloatVectorFieldData).NumRows[0]++
 			return nil
 		}
-	case schemapb.DataType_String:
-		validators[schema.GetName()].validateFunc = func(obj interface{}) error {
+	case schemapb.DataType_String, schemapb.DataType_VarChar:
+		validators[schema.GetFieldID()].validateFunc = func(obj interface{}) error {
 			switch obj.(type) {
 			case string:
 				return nil
@@ -234,7 +235,7 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
 			}
 		}
 
-		validators[schema.GetName()].convertFunc = func(obj interface{}, field storage.FieldData) error {
+		validators[schema.GetFieldID()].convertFunc = func(obj interface{}, field storage.FieldData) error {
 			value := obj.(string)
 			field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, value)
 			field.(*storage.StringFieldData).NumRows[0]++
@@ -250,14 +251,14 @@ func initValidators(collectionSchema *schemapb.CollectionSchema, validators map[
 
 // row-based json format validator class
 type JSONRowValidator struct {
-	downstream JSONRowHandler        // downstream processor, typically is a JSONRowComsumer
-	validators map[string]*Validator // validators for each field
-	rowCounter int64                 // how many rows have been validated
+	downstream JSONRowHandler                 // downstream processor, typically is a JSONRowComsumer
+	validators map[storage.FieldID]*Validator // validators for each field
+	rowCounter int64                          // how many rows have been validated
 }
 
 func NewJSONRowValidator(collectionSchema *schemapb.CollectionSchema, downstream JSONRowHandler) *JSONRowValidator {
 	v := &JSONRowValidator{
-		validators: make(map[string]*Validator),
+		validators: make(map[storage.FieldID]*Validator),
 		downstream: downstream,
 		rowCounter: 0,
 	}
@@ -270,7 +271,7 @@ func (v *JSONRowValidator) ValidateCount() int64 {
 	return v.rowCounter
 }
 
-func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error {
+func (v *JSONRowValidator) Handle(rows []map[storage.FieldID]interface{}) error {
 	if v.validators == nil || len(v.validators) == 0 {
 		return errors.New("JSON row validator is not initialized")
 	}
@@ -286,14 +287,14 @@ func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error {
 	for i := 0; i < len(rows); i++ {
 		row := rows[i]
 
-		for name, validator := range v.validators {
+		for id, validator := range v.validators {
 			if validator.primaryKey && validator.autoID {
 				// auto-generated primary key, ignore
 				continue
 			}
-			value, ok := row[name]
+			value, ok := row[id]
 			if !ok {
-				return errors.New("JSON row validator: field " + name + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
+				return errors.New("JSON row validator: fieldID " + strconv.FormatInt(id, 10) + " missed at the row " + strconv.FormatInt(v.rowCounter+int64(i), 10))
 			}
 
 			if err := validator.validateFunc(value); err != nil {
@@ -313,27 +314,27 @@ func (v *JSONRowValidator) Handle(rows []map[string]interface{}) error {
 
 // column-based json format validator class
 type JSONColumnValidator struct {
-	downstream JSONColumnHandler     // downstream processor, typically is a JSONColumnComsumer
-	validators map[string]*Validator // validators for each field
-	rowCounter map[string]int64      // row count of each field
+	downstream JSONColumnHandler              // downstream processor, typically is a JSONColumnComsumer
+	validators map[storage.FieldID]*Validator // validators for each field
+	rowCounter map[storage.FieldID]int64      // row count of each field
 }
 
 func NewJSONColumnValidator(schema *schemapb.CollectionSchema, downstream JSONColumnHandler) *JSONColumnValidator {
 	v := &JSONColumnValidator{
-		validators: make(map[string]*Validator),
+		validators: make(map[storage.FieldID]*Validator),
 		downstream: downstream,
-		rowCounter: make(map[string]int64),
+		rowCounter: make(map[storage.FieldID]int64),
 	}
 	initValidators(schema, v.validators)
 
 	return v
 }
 
-func (v *JSONColumnValidator) ValidateCount() map[string]int64 {
+func (v *JSONColumnValidator) ValidateCount() map[storage.FieldID]int64 {
 	return v.rowCounter
 }
 
-func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error {
+func (v *JSONColumnValidator) Handle(columns map[storage.FieldID][]interface{}) error {
 	if v.validators == nil || len(v.validators) == 0 {
 		return errors.New("JSON column validator is not initialized")
 	}
@@ -346,7 +347,7 @@ func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error {
 		if rowCount == -1 {
 			rowCount = counter
 		} else if rowCount != counter {
-			return errors.New("JSON column validator: the field " + k + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields " + strconv.Itoa(int(rowCount)))
+			return errors.New("JSON column validator: the field " + strconv.FormatInt(k, 10) + " row count " + strconv.Itoa(int(counter)) + " is not equal to other fields " + strconv.Itoa(int(rowCount)))
 		}
 	}
 
@@ -383,74 +384,74 @@ func (v *JSONColumnValidator) Handle(columns map[string][]interface{}) error {
 
 // row-based json format consumer class
 type JSONRowConsumer struct {
-	collectionSchema *schemapb.CollectionSchema     // collection schema
-	rowIDAllocator   *allocator.IDAllocator         // autoid allocator
-	validators       map[string]*Validator          // validators for each field
-	rowCounter       int64                          // how many rows have been consumed
-	shardNum         int32                          // sharding number of the collection
-	segmentsData     []map[string]storage.FieldData // in-memory segments data
-	segmentSize      int32                          // maximum size of a segment in MB
-	primaryKey       string                         // name of primary key
+	collectionSchema *schemapb.CollectionSchema              // collection schema
+	rowIDAllocator   *allocator.IDAllocator                  // autoid allocator
+	validators       map[storage.FieldID]*Validator          // validators for each field
+	rowCounter       int64                                   // how many rows have been consumed
+	shardNum         int32                                   // sharding number of the collection
+	segmentsData     []map[storage.FieldID]storage.FieldData // in-memory segments data
+	segmentSize      int64                                   // maximum size of a segment in MB
+	primaryKey       storage.FieldID                         // name of primary key
 
-	callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment
+	callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush segment
 }
 
-func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[string]storage.FieldData {
-	segmentData := make(map[string]storage.FieldData)
+func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[storage.FieldID]storage.FieldData {
+	segmentData := make(map[storage.FieldID]storage.FieldData)
 	for i := 0; i < len(collectionSchema.Fields); i++ {
 		schema := collectionSchema.Fields[i]
 		switch schema.DataType {
 		case schemapb.DataType_Bool:
-			segmentData[schema.GetName()] = &storage.BoolFieldData{
+			segmentData[schema.GetFieldID()] = &storage.BoolFieldData{
 				Data:    make([]bool, 0),
 				NumRows: []int64{0},
 			}
 		case schemapb.DataType_Float:
-			segmentData[schema.GetName()] = &storage.FloatFieldData{
+			segmentData[schema.GetFieldID()] = &storage.FloatFieldData{
 				Data:    make([]float32, 0),
 				NumRows: []int64{0},
 			}
 		case schemapb.DataType_Double:
-			segmentData[schema.GetName()] = &storage.DoubleFieldData{
+			segmentData[schema.GetFieldID()] = &storage.DoubleFieldData{
 				Data:    make([]float64, 0),
 				NumRows: []int64{0},
 			}
 		case schemapb.DataType_Int8:
-			segmentData[schema.GetName()] = &storage.Int8FieldData{
+			segmentData[schema.GetFieldID()] = &storage.Int8FieldData{
 				Data:    make([]int8, 0),
 				NumRows: []int64{0},
 			}
 		case schemapb.DataType_Int16:
-			segmentData[schema.GetName()] = &storage.Int16FieldData{
+			segmentData[schema.GetFieldID()] = &storage.Int16FieldData{
 				Data:    make([]int16, 0),
 				NumRows: []int64{0},
 			}
 		case schemapb.DataType_Int32:
-			segmentData[schema.GetName()] = &storage.Int32FieldData{
+			segmentData[schema.GetFieldID()] = &storage.Int32FieldData{
 				Data:    make([]int32, 0),
 				NumRows: []int64{0},
 			}
 		case schemapb.DataType_Int64:
-			segmentData[schema.GetName()] = &storage.Int64FieldData{
+			segmentData[schema.GetFieldID()] = &storage.Int64FieldData{
 				Data:    make([]int64, 0),
 				NumRows: []int64{0},
 			}
 		case schemapb.DataType_BinaryVector:
 			dim, _ := getFieldDimension(schema)
-			segmentData[schema.GetName()] = &storage.BinaryVectorFieldData{
+			segmentData[schema.GetFieldID()] = &storage.BinaryVectorFieldData{
 				Data:    make([]byte, 0),
 				NumRows: []int64{0},
 				Dim:     dim,
 			}
 		case schemapb.DataType_FloatVector:
 			dim, _ := getFieldDimension(schema)
-			segmentData[schema.GetName()] = &storage.FloatVectorFieldData{
+			segmentData[schema.GetFieldID()] = &storage.FloatVectorFieldData{
 				Data:    make([]float32, 0),
 				NumRows: []int64{0},
 				Dim:     dim,
 			}
-		case schemapb.DataType_String:
-			segmentData[schema.GetName()] = &storage.StringFieldData{
+		case schemapb.DataType_String, schemapb.DataType_VarChar:
+			segmentData[schema.GetFieldID()] = &storage.StringFieldData{
 				Data:    make([]string, 0),
 				NumRows: []int64{0},
 			}
@@ -463,8 +464,8 @@ func initSegmentData(collectionSchema *schemapb.CollectionSchema) map[string]sto
 	return segmentData
 }
 
-func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int32,
-	flushFunc func(fields map[string]storage.FieldData) error) *JSONRowConsumer {
+func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *allocator.IDAllocator, shardNum int32, segmentSize int64,
+	flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *JSONRowConsumer {
 	if collectionSchema == nil {
 		log.Error("JSON row consumer: collection schema is nil")
 		return nil
@@ -473,16 +474,17 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
 	v := &JSONRowConsumer{
 		collectionSchema: collectionSchema,
 		rowIDAllocator:   idAlloc,
-		validators:       make(map[string]*Validator),
+		validators:       make(map[storage.FieldID]*Validator),
 		shardNum:         shardNum,
 		segmentSize:      segmentSize,
 		rowCounter:       0,
+		primaryKey:       -1,
 		callFlushFunc:    flushFunc,
 	}
 
 	initValidators(collectionSchema, v.validators)
 
-	v.segmentsData = make([]map[string]storage.FieldData, 0, shardNum)
+	v.segmentsData = make([]map[storage.FieldID]storage.FieldData, 0, shardNum)
 	for i := 0; i < int(shardNum); i++ {
 		segmentData := initSegmentData(collectionSchema)
 		if segmentData == nil {
@@ -494,12 +496,12 @@ func NewJSONRowConsumer(collectionSchema *schemapb.CollectionSchema, idAlloc *al
 	for i := 0; i < len(collectionSchema.Fields); i++ {
 		schema := collectionSchema.Fields[i]
 		if schema.GetIsPrimaryKey() {
-			v.primaryKey = schema.GetName()
+			v.primaryKey = schema.GetFieldID()
 			break
 		}
 	}
 	// primary key not found
-	if v.primaryKey == "" {
+	if v.primaryKey == -1 {
 		log.Error("JSON row consumer: collection schema has no primary key")
 		return nil
 	}
@@ -544,7 +546,7 @@ func (v *JSONRowConsumer) flush(force bool) error {
 	return nil
 }
 
-func (v *JSONRowConsumer) Handle(rows []map[string]interface{}) error {
+func (v *JSONRowConsumer) Handle(rows []map[storage.FieldID]interface{}) error {
 	if v.validators == nil || len(v.validators) == 0 {
 		return errors.New("JSON row consumer is not initialized")
 	}
@@ -614,23 +616,23 @@ func (v *JSONRowConsumer) Handle(rows []map[string]interface{}) error {
 
 // column-based json format consumer class
 type JSONColumnConsumer struct {
-	collectionSchema *schemapb.CollectionSchema   // collection schema
-	validators       map[string]*Validator        // validators for each field
-	fieldsData       map[string]storage.FieldData // in-memory fields data
-	primaryKey       string                       // name of primary key
+	collectionSchema *schemapb.CollectionSchema            // collection schema
+	validators       map[storage.FieldID]*Validator        // validators for each field
+	fieldsData       map[storage.FieldID]storage.FieldData // in-memory fields data
+	primaryKey       storage.FieldID                       // name of primary key
 
-	callFlushFunc func(fields map[string]storage.FieldData) error // call back function to flush segment
+	callFlushFunc func(fields map[storage.FieldID]storage.FieldData) error // call back function to flush segment
 }
 
 func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
-	flushFunc func(fields map[string]storage.FieldData) error) *JSONColumnConsumer {
+	flushFunc func(fields map[storage.FieldID]storage.FieldData) error) *JSONColumnConsumer {
 	if collectionSchema == nil {
 		return nil
 	}
 
 	v := &JSONColumnConsumer{
 		collectionSchema: collectionSchema,
-		validators:       make(map[string]*Validator),
+		validators:       make(map[storage.FieldID]*Validator),
 		callFlushFunc:    flushFunc,
 	}
 	initValidators(collectionSchema, v.validators)
@@ -639,7 +641,7 @@ func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
 	for i := 0; i < len(collectionSchema.Fields); i++ {
 		schema := collectionSchema.Fields[i]
 		if schema.GetIsPrimaryKey() {
-			v.primaryKey = schema.GetName()
+			v.primaryKey = schema.GetFieldID()
 			break
 		}
 	}
@@ -650,9 +652,9 @@ func NewJSONColumnConsumer(collectionSchema *schemapb.CollectionSchema,
 func (v *JSONColumnConsumer) flush() error {
 	// check row count, should be equal
 	rowCount := 0
-	for name, field := range v.fieldsData {
+	for id, field := range v.fieldsData {
 		// skip the autoid field
-		if name == v.primaryKey && v.validators[v.primaryKey].autoID {
+		if id == v.primaryKey && v.validators[v.primaryKey].autoID {
 			continue
 		}
 		cnt := field.RowNum()
@@ -665,7 +667,7 @@ func (v *JSONColumnConsumer) flush() error {
 		if rowCount == 0 {
 			rowCount = cnt
 		} else if rowCount != cnt {
-			return errors.New("JSON column consumer: " + name + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
+			return errors.New("JSON column consumer: " + strconv.FormatInt(id, 10) + " row count " + strconv.Itoa(cnt) + " doesn't equal " + strconv.Itoa(rowCount))
 		}
 	}
 
@@ -678,7 +680,7 @@ func (v *JSONColumnConsumer) flush() error {
 	return v.callFlushFunc(v.fieldsData)
 }
 
-func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error {
+func (v *JSONColumnConsumer) Handle(columns map[storage.FieldID][]interface{}) error {
 	if v.validators == nil || len(v.validators) == 0 {
 		return errors.New("JSON column consumer is not initialized")
 	}
@@ -691,10 +693,10 @@ func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error {
 	}
 
 	// consume columns data
-	for name, values := range columns {
-		validator, ok := v.validators[name]
+	for id, values := range columns {
+		validator, ok := v.validators[id]
 		if !ok {
-			// not a valid field name
+			// not a valid field id
 			break
 		}
 
@@ -705,8 +707,8 @@ func (v *JSONColumnConsumer) Handle(columns map[string][]interface{}) error {
 
 		// convert and consume data
 		for i := 0; i < len(values); i++ {
-			if err := validator.convertFunc(values[i], v.fieldsData[name]); err != nil {
-				return errors.New("JSON column consumer: " + err.Error() + " of field " + name)
+			if err := validator.convertFunc(values[i], v.fieldsData[id]); err != nil {
+				return errors.New("JSON column consumer: " + err.Error() + " of field " + strconv.FormatInt(id, 10))
 			}
 		}
 	}
diff --git a/internal/util/importutil/json_handler_test.go b/internal/util/importutil/json_handler_test.go
index 1d607477e7..d91da9fe05 100644
--- a/internal/util/importutil/json_handler_test.go
+++ b/internal/util/importutil/json_handler_test.go
@@ -5,12 +5,13 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
+
 	"github.com/milvus-io/milvus/internal/allocator"
 	"github.com/milvus-io/milvus/internal/proto/commonpb"
 	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
 	"github.com/milvus-io/milvus/internal/proto/schemapb"
 	"github.com/milvus-io/milvus/internal/storage"
-	"github.com/stretchr/testify/assert"
 )
 
 type mockIDAllocator struct {
@@ -68,17 +69,23 @@ func Test_GetFieldDimension(t *testing.T) {
 }
 
 func Test_InitValidators(t *testing.T) {
-	validators := make(map[string]*Validator)
+	validators := make(map[storage.FieldID]*Validator)
 	err := initValidators(nil, validators)
 	assert.NotNil(t, err)
 
+	schema := sampleSchema()
 	// success case
-	err = initValidators(sampleSchema(), validators)
+	err = initValidators(schema, validators)
 	assert.Nil(t, err)
-	assert.Equal(t, len(sampleSchema().Fields), len(validators))
+	assert.Equal(t, len(schema.Fields), len(validators))
+	name2ID := make(map[string]storage.FieldID)
+	for _, field := range schema.Fields {
+		name2ID[field.GetName()] = field.GetFieldID()
+	}
 
 	checkFunc := func(funcName string, validVal interface{}, invalidVal interface{}) {
-		v, ok := validators[funcName]
+		id := name2ID[funcName]
+		v, ok := validators[id]
 		assert.True(t, ok)
 		err = v.validateFunc(validVal)
 		assert.Nil(t, err)
@@ -127,7 +134,7 @@ func Test_InitValidators(t *testing.T) {
 	checkFunc("field_float_vector", validVal, invalidVal)
 
 	// error cases
-	schema := &schemapb.CollectionSchema{
+	schema = &schemapb.CollectionSchema{
 		Name:        "schema",
 		Description: "schema",
 		AutoID:      true,
@@ -144,7 +151,7 @@ func Test_InitValidators(t *testing.T) {
 		},
 	})
 
-	validators = make(map[string]*Validator)
+	validators = make(map[storage.FieldID]*Validator)
 	err = initValidators(schema, validators)
 	assert.NotNil(t, err)
 
@@ -308,7 +315,7 @@ func Test_JSONRowConsumer(t *testing.T) {
 	var callTime int32
 	var totalCount int
 
-	consumeFunc := func(fields map[string]storage.FieldData) error {
+	consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error {
 		callTime++
 		rowCount := 0
 		for _, data := range fields {
@@ -370,7 +377,7 @@ func Test_JSONColumnConsumer(t *testing.T) {
 	callTime := 0
 	rowCount := 0
 
-	consumeFunc := func(fields map[string]storage.FieldData) error {
+	consumeFunc := func(fields map[storage.FieldID]storage.FieldData) error {
 		callTime++
 		for _, data := range fields {
 			if rowCount == 0 {
diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go
index cf0d67738f..bd0cb317a5 100644
--- a/internal/util/importutil/json_parser.go
+++ b/internal/util/importutil/json_parser.go
@@ -9,6 +9,7 @@ import (
 
 	"github.com/milvus-io/milvus/internal/log"
 	"github.com/milvus-io/milvus/internal/proto/schemapb"
+	"github.com/milvus-io/milvus/internal/storage"
 )
 
 const (
@@ -19,23 +20,27 @@ const (
 )
 
 type JSONParser struct {
-	ctx     context.Context  // for canceling parse process
-	bufSize int64            // max rows in a buffer
-	fields  map[string]int64 // fields need to be parsed
+	ctx          context.Context  // for canceling parse process
+	bufSize      int64            // max rows in a buffer
+	fields       map[string]int64 // fields need to be parsed
+	name2FieldID map[string]storage.FieldID
 }
 
 // NewJSONParser helper function to create a JSONParser
 func NewJSONParser(ctx context.Context, collectionSchema *schemapb.CollectionSchema) *JSONParser {
 	fields := make(map[string]int64)
+	name2FieldID := make(map[string]storage.FieldID)
 	for i := 0; i < len(collectionSchema.Fields); i++ {
 		schema := collectionSchema.Fields[i]
 		fields[schema.GetName()] = 0
+		name2FieldID[schema.GetName()] = schema.GetFieldID()
 	}
 
 	parser := &JSONParser{
-		ctx:     ctx,
-		bufSize: 4096,
-		fields:  fields,
+		ctx:          ctx,
+		bufSize:      4096,
+		fields:       fields,
+		name2FieldID: name2FieldID,
 	}
 
 	return parser
@@ -87,7 +92,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
 			}
 
 			// read buffer
-			buf := make([]map[string]interface{}, 0, BufferSize)
+			buf := make([]map[storage.FieldID]interface{}, 0, BufferSize)
 			for dec.More() {
 				var value interface{}
 				if err := dec.Decode(&value); err != nil {
@@ -101,7 +106,11 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
 					return p.logError("JSON parse: invalid JSON format, each row should be a key-value map")
 				}
 
-				row := value.(map[string]interface{})
+				row := make(map[storage.FieldID]interface{})
+				stringMap := value.(map[string]interface{})
+				for k, v := range stringMap {
+					row[p.name2FieldID[k]] = v
+				}
 				buf = append(buf, row)
 				if len(buf) >= int(p.bufSize) {
@@ -110,7 +119,7 @@ func (p *JSONParser) ParseRows(r io.Reader, handler JSONRowHandler) error {
 					}
 
 					// clear the buffer
-					buf = make([]map[string]interface{}, 0, BufferSize)
+					buf = make([]map[storage.FieldID]interface{}, 0, BufferSize)
 				}
 			}
 
@@ -185,9 +194,10 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error
 			return p.logError("JSON parse: invalid column-based JSON format, each field should begin with '['")
 		}
 
+		id := p.name2FieldID[key]
 		// read buffer
-		buf := make(map[string][]interface{})
-		buf[key] = make([]interface{}, 0, BufferSize)
+		buf := make(map[storage.FieldID][]interface{})
+		buf[id] = make([]interface{}, 0, BufferSize)
 		for dec.More() {
 			var value interface{}
 			if err := dec.Decode(&value); err != nil {
@@ -198,19 +208,19 @@ func (p *JSONParser) ParseColumns(r io.Reader, handler JSONColumnHandler) error
 				continue
 			}
 
-			buf[key] = append(buf[key], value)
-			if len(buf[key]) >= int(p.bufSize) {
+			buf[id] = append(buf[id], value)
+			if len(buf[id]) >= int(p.bufSize) {
 				if err = handler.Handle(buf); err != nil {
 					return p.logError(err.Error())
 				}
 
 				// clear the buffer
-				buf[key] = make([]interface{}, 0, BufferSize)
+				buf[id] = make([]interface{}, 0, BufferSize)
 			}
 		}
 
 		// some values in buffer not parsed, parse them
-		if len(buf[key]) > 0 {
+		if len(buf[id]) > 0 {
 			if err = handler.Handle(buf); err != nil {
 				return p.logError(err.Error())
 			}
diff --git a/internal/util/importutil/numpy_adapter.go b/internal/util/importutil/numpy_adapter.go
index 7e03f58275..399f885437 100644
--- a/internal/util/importutil/numpy_adapter.go
+++ b/internal/util/importutil/numpy_adapter.go
@@ -1,6 +1,7 @@
 package importutil
 
 import (
+	"bytes"
 	"encoding/binary"
 	"errors"
 	"io"
@@ -25,6 +26,16 @@ func CreateNumpyFile(path string, data interface{}) error {
 	return nil
 }
 
+func CreateNumpyData(data interface{}) ([]byte, error) {
+	buf := new(bytes.Buffer)
+	err := npyio.Write(buf, data)
+	if err != nil {
+		return nil, err
+	}
+
+	return buf.Bytes(), nil
+}
+
 // a class to expand other numpy lib ability
 // we evaluate two go-numpy lins: github.com/kshedden/gonpy and github.com/sbinet/npyio
 // the npyio lib read data one by one, the performance is poor, we expand the read methods
diff --git a/internal/util/importutil/numpy_parser_test.go b/internal/util/importutil/numpy_parser_test.go
index 4bb1207486..84edd18e86 100644
--- a/internal/util/importutil/numpy_parser_test.go
+++ b/internal/util/importutil/numpy_parser_test.go
@@ -5,11 +5,12 @@ import (
 	"os"
 	"testing"
 
+	"github.com/sbinet/npyio/npy"
+	"github.com/stretchr/testify/assert"
+
 	"github.com/milvus-io/milvus/internal/proto/schemapb"
 	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/timerecord"
-	"github.com/sbinet/npyio/npy"
-	"github.com/stretchr/testify/assert"
 )
 
 func Test_NewNumpyParser(t *testing.T) {