diff --git a/client/milvusclient/collection.go b/client/milvusclient/collection.go index 94ad0830b9..5155ccfcc3 100644 --- a/client/milvusclient/collection.go +++ b/client/milvusclient/collection.go @@ -167,10 +167,6 @@ func (c *Client) AlterCollectionFieldProperty(ctx context.Context, option AlterC }) } -type GetCollectionOption interface { - Request() *milvuspb.GetCollectionStatisticsRequest -} - func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption) (map[string]string, error) { var stats map[string]string err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { @@ -186,3 +182,14 @@ func (c *Client) GetCollectionStats(ctx context.Context, opt GetCollectionOption } return stats, nil } + +// AddCollectionField adds a field to a collection. +func (c *Client) AddCollectionField(ctx context.Context, opt AddCollectionFieldOption, callOpts ...grpc.CallOption) error { + req := opt.Request() + + err := c.callService(func(milvusService milvuspb.MilvusServiceClient) error { + resp, err := milvusService.AddCollectionField(ctx, req, callOpts...) + return merr.CheckRPCCall(resp, err) + }) + return err +} diff --git a/client/milvusclient/collection_example_test.go b/client/milvusclient/collection_example_test.go index 88786921f2..0a0de980f0 100644 --- a/client/milvusclient/collection_example_test.go +++ b/client/milvusclient/collection_example_test.go @@ -524,3 +524,28 @@ func ExampleClient_DropCollection() { // handle err } } + +func ExampleClient_AddCollectionField() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + milvusAddr := "127.0.0.1:19530" + + cli, err := milvusclient.New(ctx, &milvusclient.ClientConfig{ + Address: milvusAddr, + }) + if err != nil { + log.Fatal("failed to connect to milvus server: ", err.Error()) + } + + defer cli.Close(ctx) + + // the field to add + // must be nullable for now + newField := entity.NewField().WithName("new_field").WithDataType(entity.FieldTypeInt64).WithNullable(true) + + err = cli.AddCollectionField(ctx, milvusclient.NewAddCollectionFieldOption("customized_setup_2", newField)) + if err != nil { + // handle error + } +} diff --git a/client/milvusclient/collection_options.go b/client/milvusclient/collection_options.go index fc271aa1f2..ba994ace8e 100644 --- a/client/milvusclient/collection_options.go +++ b/client/milvusclient/collection_options.go @@ -385,6 +385,10 @@ func NewAlterCollectionFieldPropertiesOption(collectionName string, fieldName st } } +type GetCollectionOption interface { + Request() *milvuspb.GetCollectionStatisticsRequest +} + type getCollectionStatsOption struct { collectionName string } @@ -398,3 +402,27 @@ func (opt *getCollectionStatsOption) Request() *milvuspb.GetCollectionStatistics func NewGetCollectionStatsOption(collectionName string) *getCollectionStatsOption { return &getCollectionStatsOption{collectionName: collectionName} } + +type AddCollectionFieldOption interface { + Request() *milvuspb.AddCollectionFieldRequest +} + +type addCollectionFieldOption struct { + collectionName string + fieldSch *entity.Field +} + +func (c *addCollectionFieldOption) Request() *milvuspb.AddCollectionFieldRequest { + bs, _ := proto.Marshal(c.fieldSch.ProtoMessage()) + return &milvuspb.AddCollectionFieldRequest{ + CollectionName: c.collectionName, + Schema: bs, + } +} + +func NewAddCollectionFieldOption(collectionName string, field *entity.Field) *addCollectionFieldOption { + return &addCollectionFieldOption{ + collectionName: collectionName, + fieldSch: field, + } +} diff --git a/client/milvusclient/collection_test.go b/client/milvusclient/collection_test.go index 327467ad97..18d21699c1 100644 --- a/client/milvusclient/collection_test.go +++ b/client/milvusclient/collection_test.go @@ -408,6 +408,41 @@ func (s *CollectionSuite) TestGetCollectionStats() { }) } +func (s *CollectionSuite) TestAddCollectionField() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s.Run("success", func() { + collName := fmt.Sprintf("coll_%s", s.randString(6)) + fieldName := fmt.Sprintf("field_%s", s.randString(6)) + s.mock.EXPECT().AddCollectionField(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, acfr *milvuspb.AddCollectionFieldRequest) (*commonpb.Status, error) { + fieldProto := &schemapb.FieldSchema{} + err := proto.Unmarshal(acfr.GetSchema(), fieldProto) + s.Require().NoError(err) + s.Equal(fieldName, fieldProto.GetName()) + s.Equal(schemapb.DataType_Int64, fieldProto.GetDataType()) + s.True(fieldProto.GetNullable()) + return merr.Success(), nil + }).Once() + + field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeInt64).WithNullable(true) + + err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field)) + s.NoError(err) + }) + + s.Run("failure", func() { + collName := fmt.Sprintf("coll_%s", s.randString(6)) + fieldName := fmt.Sprintf("field_%s", s.randString(6)) + s.mock.EXPECT().AddCollectionField(mock.Anything, mock.Anything).Return(merr.Status(errors.New("mocked")), nil).Once() + + field := entity.NewField().WithName(fieldName).WithDataType(entity.FieldTypeInt64).WithNullable(true) + + err := s.client.AddCollectionField(ctx, NewAddCollectionFieldOption(collName, field)) + s.Error(err) + }) +} + func TestCollection(t *testing.T) { suite.Run(t, new(CollectionSuite)) }