diff --git a/client/entity/common.go b/client/entity/common.go index 2de5ee3918..ec794e3db1 100644 --- a/client/entity/common.go +++ b/client/entity/common.go @@ -29,4 +29,5 @@ const ( TANIMOTO MetricType = "TANIMOTO" SUBSTRUCTURE MetricType = "SUBSTRUCTURE" SUPERSTRUCTURE MetricType = "SUPERSTRUCTURE" + BM25 MetricType = "BM25" ) diff --git a/client/entity/field.go b/client/entity/field.go new file mode 100644 index 0000000000..d2765ae747 --- /dev/null +++ b/client/entity/field.go @@ -0,0 +1,401 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package entity + +import ( + "encoding/json" + "strconv" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +// FieldType field data type alias type +// used in go:generate trick, DO NOT modify names & string +type FieldType int32 + +// Name returns field type name +func (t FieldType) Name() string { + switch t { + case FieldTypeBool: + return "Bool" + case FieldTypeInt8: + return "Int8" + case FieldTypeInt16: + return "Int16" + case FieldTypeInt32: + return "Int32" + case FieldTypeInt64: + return "Int64" + case FieldTypeFloat: + return "Float" + case FieldTypeDouble: + return "Double" + case FieldTypeString: + return "String" + case FieldTypeVarChar: + return "VarChar" + case FieldTypeArray: + return "Array" + case FieldTypeJSON: + return "JSON" + case FieldTypeBinaryVector: + return "BinaryVector" + case FieldTypeFloatVector: + return "FloatVector" + case FieldTypeFloat16Vector: + return "Float16Vector" + case FieldTypeBFloat16Vector: + return "BFloat16Vector" + default: + return "undefined" + } +} + +// String returns field type +func (t FieldType) String() string { + switch t { + case FieldTypeBool: + return "bool" + case FieldTypeInt8: + return "int8" + case FieldTypeInt16: + return "int16" + case FieldTypeInt32: + return "int32" + case FieldTypeInt64: + return "int64" + case FieldTypeFloat: + return "float32" + case FieldTypeDouble: + return "float64" + case FieldTypeString: + return "string" + case FieldTypeVarChar: + return "string" + case FieldTypeArray: + return "Array" + case FieldTypeJSON: + return "JSON" + case FieldTypeBinaryVector: + return "[]byte" + case FieldTypeFloatVector: + return "[]float32" + case FieldTypeFloat16Vector: + return "[]byte" + case FieldTypeBFloat16Vector: + return "[]byte" + default: + return "undefined" + } +} + +// PbFieldType represents FieldType corresponding schema pb type +func (t FieldType) PbFieldType() (string, string) { + switch t { + case FieldTypeBool: + return "Bool", "bool" + case FieldTypeInt8: + fallthrough + case FieldTypeInt16: + fallthrough + case FieldTypeInt32: + return "Int", "int32" + case FieldTypeInt64: + return "Long", "int64" + case FieldTypeFloat: + return "Float", "float32" + case FieldTypeDouble: + return "Double", "float64" + case FieldTypeString: + return "String", "string" + case FieldTypeVarChar: + return "VarChar", "string" + case FieldTypeJSON: + return "JSON", "JSON" + case FieldTypeBinaryVector: + return "[]byte", "" + case FieldTypeFloatVector: + return "[]float32", "" + case FieldTypeFloat16Vector: + return "[]byte", "" + case FieldTypeBFloat16Vector: + return "[]byte", "" + default: + return "undefined", "" + } +} + +// Match schema definition +const ( + // FieldTypeNone zero value place holder + FieldTypeNone FieldType = 0 // zero value place holder + // FieldTypeBool field type boolean + FieldTypeBool FieldType = 1 + // FieldTypeInt8 field type int8 + FieldTypeInt8 FieldType = 2 + // FieldTypeInt16 field type int16 + FieldTypeInt16 FieldType = 3 + // FieldTypeInt32 field type int32 + FieldTypeInt32 FieldType = 4 + // FieldTypeInt64 field type int64 + FieldTypeInt64 FieldType = 5 + // FieldTypeFloat field type float + FieldTypeFloat FieldType = 10 + // FieldTypeDouble field type double + FieldTypeDouble FieldType = 11 + // FieldTypeString field type string + FieldTypeString FieldType = 20 + // FieldTypeVarChar field type varchar + FieldTypeVarChar FieldType = 21 // variable-length strings with a specified maximum length + // FieldTypeArray field type Array + FieldTypeArray FieldType = 22 + // FieldTypeJSON field type JSON + FieldTypeJSON FieldType = 23 + // FieldTypeBinaryVector field type binary vector + FieldTypeBinaryVector FieldType = 100 + // FieldTypeFloatVector field type float vector + FieldTypeFloatVector FieldType = 101 + // FieldTypeBinaryVector field type float16 vector + FieldTypeFloat16Vector FieldType = 102 + // FieldTypeBinaryVector field type bf16 vector + FieldTypeBFloat16Vector FieldType = 103 + // FieldTypeBinaryVector field type sparse vector + FieldTypeSparseVector FieldType = 104 +) + +// Field represent field schema in milvus +type Field struct { + ID int64 // field id, generated when collection is created, input value is ignored + Name string // field name + PrimaryKey bool // is primary key + AutoID bool // is auto id + Description string + DataType FieldType + TypeParams map[string]string + IndexParams map[string]string + IsDynamic bool + IsPartitionKey bool + IsClusteringKey bool + ElementType FieldType +} + +// ProtoMessage generates corresponding FieldSchema +func (f *Field) ProtoMessage() *schemapb.FieldSchema { + return &schemapb.FieldSchema{ + FieldID: f.ID, + Name: f.Name, + Description: f.Description, + IsPrimaryKey: f.PrimaryKey, + AutoID: f.AutoID, + DataType: schemapb.DataType(f.DataType), + TypeParams: MapKvPairs(f.TypeParams), + IndexParams: MapKvPairs(f.IndexParams), + IsDynamic: f.IsDynamic, + IsPartitionKey: f.IsPartitionKey, + IsClusteringKey: f.IsClusteringKey, + ElementType: schemapb.DataType(f.ElementType), + } +} + +// NewField creates a new Field with map initialized. +func NewField() *Field { + return &Field{ + TypeParams: make(map[string]string), + IndexParams: make(map[string]string), + } +} + +func (f *Field) WithName(name string) *Field { + f.Name = name + return f +} + +func (f *Field) WithDescription(desc string) *Field { + f.Description = desc + return f +} + +func (f *Field) WithDataType(dataType FieldType) *Field { + f.DataType = dataType + return f +} + +func (f *Field) WithIsPrimaryKey(isPrimaryKey bool) *Field { + f.PrimaryKey = isPrimaryKey + return f +} + +func (f *Field) WithIsAutoID(isAutoID bool) *Field { + f.AutoID = isAutoID + return f +} + +func (f *Field) WithIsDynamic(isDynamic bool) *Field { + f.IsDynamic = isDynamic + return f +} + +func (f *Field) WithIsPartitionKey(isPartitionKey bool) *Field { + f.IsPartitionKey = isPartitionKey + return f +} + +func (f *Field) WithIsClusteringKey(isClusteringKey bool) *Field { + f.IsClusteringKey = isClusteringKey + return f +} + +/* +func (f *Field) WithDefaultValueBool(defaultValue bool) *Field { + f.DefaultValue = &schemapb.ValueField{ + Data: &schemapb.ValueField_BoolData{ + BoolData: defaultValue, + }, + } + return f +} + +func (f *Field) WithDefaultValueInt(defaultValue int32) *Field { + f.DefaultValue = &schemapb.ValueField{ + Data: &schemapb.ValueField_IntData{ + IntData: defaultValue, + }, + } + return f +} + +func (f *Field) WithDefaultValueLong(defaultValue int64) *Field { + f.DefaultValue = &schemapb.ValueField{ + Data: &schemapb.ValueField_LongData{ + LongData: defaultValue, + }, + } + return f +} + +func (f *Field) WithDefaultValueFloat(defaultValue float32) *Field { + f.DefaultValue = &schemapb.ValueField{ + Data: &schemapb.ValueField_FloatData{ + FloatData: defaultValue, + }, + } + return f +} + +func (f *Field) WithDefaultValueDouble(defaultValue float64) *Field { + f.DefaultValue = &schemapb.ValueField{ + Data: &schemapb.ValueField_DoubleData{ + DoubleData: defaultValue, + }, + } + return f +} + +func (f *Field) WithDefaultValueString(defaultValue string) *Field { + f.DefaultValue = &schemapb.ValueField{ + Data: &schemapb.ValueField_StringData{ + StringData: defaultValue, + }, + } + return f +}*/ + +func (f *Field) WithTypeParams(key string, value string) *Field { + if f.TypeParams == nil { + f.TypeParams = make(map[string]string) + } + f.TypeParams[key] = value + return f +} + +func (f *Field) WithDim(dim int64) *Field { + if f.TypeParams == nil { + f.TypeParams = make(map[string]string) + } + f.TypeParams[TypeParamDim] = strconv.FormatInt(dim, 10) + return f +} + +func (f *Field) GetDim() (int64, error) { + dimStr, has := f.TypeParams[TypeParamDim] + if !has { + return -1, errors.New("field with no dim") + } + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return -1, errors.Newf("field with bad format dim: %s", err.Error()) + } + return dim, nil +} + +func (f *Field) WithMaxLength(maxLen int64) *Field { + if f.TypeParams == nil { + f.TypeParams = make(map[string]string) + } + f.TypeParams[TypeParamMaxLength] = strconv.FormatInt(maxLen, 10) + return f +} + +func (f *Field) WithElementType(eleType FieldType) *Field { + f.ElementType = eleType + return f +} + +func (f *Field) WithMaxCapacity(maxCap int64) *Field { + if f.TypeParams == nil { + f.TypeParams = make(map[string]string) + } + f.TypeParams[TypeParamMaxCapacity] = strconv.FormatInt(maxCap, 10) + return f +} + +func (f *Field) WithEnableAnalyzer(enable bool) *Field { + if f.TypeParams == nil { + f.TypeParams = make(map[string]string) + } + f.TypeParams["enable_analyzer"] = strconv.FormatBool(enable) + return f +} + +func (f *Field) WithAnalyzerParams(params map[string]any) *Field { + if f.TypeParams == nil { + f.TypeParams = make(map[string]string) + } + bs, _ := json.Marshal(params) + f.TypeParams["analyzer_params"] = string(bs) + return f +} + +// ReadProto parses FieldSchema +func (f *Field) ReadProto(p *schemapb.FieldSchema) *Field { + f.ID = p.GetFieldID() + f.Name = p.GetName() + f.PrimaryKey = p.GetIsPrimaryKey() + f.AutoID = p.GetAutoID() + f.Description = p.GetDescription() + f.DataType = FieldType(p.GetDataType()) + f.TypeParams = KvPairsMap(p.GetTypeParams()) + f.IndexParams = KvPairsMap(p.GetIndexParams()) + f.IsDynamic = p.GetIsDynamic() + f.IsPartitionKey = p.GetIsPartitionKey() + f.IsClusteringKey = p.GetIsClusteringKey() + f.ElementType = FieldType(p.GetElementType()) + + return f +} diff --git a/client/entity/field_test.go b/client/entity/field_test.go new file mode 100644 index 0000000000..3528b36a2a --- /dev/null +++ b/client/entity/field_test.go @@ -0,0 +1,74 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package entity + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFieldSchema(t *testing.T) { + fields := []*Field{ + NewField().WithName("int_field").WithDataType(FieldTypeInt64).WithIsAutoID(true).WithIsPrimaryKey(true).WithDescription("int_field desc"), + NewField().WithName("string_field").WithDataType(FieldTypeString).WithIsAutoID(false).WithIsPrimaryKey(true).WithIsDynamic(false).WithTypeParams("max_len", "32").WithDescription("string_field desc"), + NewField().WithName("partition_key").WithDataType(FieldTypeInt32).WithIsPartitionKey(true), + NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128), + NewField().WithName("clustering_key").WithDataType(FieldTypeInt32).WithIsClusteringKey(true), + NewField().WithName("varchar_text").WithDataType(FieldTypeVarChar).WithMaxLength(65535).WithEnableAnalyzer(true).WithAnalyzerParams(map[string]any{}), + /* + NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true), + NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1), + NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1), + NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1), + NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1), + NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/ + } + + for _, field := range fields { + fieldSchema := field.ProtoMessage() + assert.Equal(t, field.ID, fieldSchema.GetFieldID()) + assert.Equal(t, field.Name, fieldSchema.GetName()) + assert.EqualValues(t, field.DataType, fieldSchema.GetDataType()) + assert.Equal(t, field.AutoID, fieldSchema.GetAutoID()) + assert.Equal(t, field.PrimaryKey, fieldSchema.GetIsPrimaryKey()) + assert.Equal(t, field.IsPartitionKey, fieldSchema.GetIsPartitionKey()) + assert.Equal(t, field.IsClusteringKey, fieldSchema.GetIsClusteringKey()) + assert.Equal(t, field.IsDynamic, fieldSchema.GetIsDynamic()) + assert.Equal(t, field.Description, fieldSchema.GetDescription()) + assert.Equal(t, field.TypeParams, KvPairsMap(fieldSchema.GetTypeParams())) + assert.EqualValues(t, field.ElementType, fieldSchema.GetElementType()) + // marshal & unmarshal, still equals + nf := &Field{} + nf = nf.ReadProto(fieldSchema) + assert.Equal(t, field.ID, nf.ID) + assert.Equal(t, field.Name, nf.Name) + assert.EqualValues(t, field.DataType, nf.DataType) + assert.Equal(t, field.AutoID, nf.AutoID) + assert.Equal(t, field.PrimaryKey, nf.PrimaryKey) + assert.Equal(t, field.Description, nf.Description) + assert.Equal(t, field.IsDynamic, nf.IsDynamic) + assert.Equal(t, field.IsPartitionKey, nf.IsPartitionKey) + assert.Equal(t, field.IsClusteringKey, nf.IsClusteringKey) + assert.EqualValues(t, field.TypeParams, nf.TypeParams) + assert.EqualValues(t, field.ElementType, nf.ElementType) + } + + assert.NotPanics(t, func() { + (&Field{}).WithTypeParams("a", "b") + }) +} diff --git a/client/entity/field_type.go b/client/entity/field_type.go deleted file mode 100644 index 9c96aa20ae..0000000000 --- a/client/entity/field_type.go +++ /dev/null @@ -1,171 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package entity - -// FieldType field data type alias type -// used in go:generate trick, DO NOT modify names & string -type FieldType int32 - -// Name returns field type name -func (t FieldType) Name() string { - switch t { - case FieldTypeBool: - return "Bool" - case FieldTypeInt8: - return "Int8" - case FieldTypeInt16: - return "Int16" - case FieldTypeInt32: - return "Int32" - case FieldTypeInt64: - return "Int64" - case FieldTypeFloat: - return "Float" - case FieldTypeDouble: - return "Double" - case FieldTypeString: - return "String" - case FieldTypeVarChar: - return "VarChar" - case FieldTypeArray: - return "Array" - case FieldTypeJSON: - return "JSON" - case FieldTypeBinaryVector: - return "BinaryVector" - case FieldTypeFloatVector: - return "FloatVector" - case FieldTypeFloat16Vector: - return "Float16Vector" - case FieldTypeBFloat16Vector: - return "BFloat16Vector" - default: - return "undefined" - } -} - -// String returns field type -func (t FieldType) String() string { - switch t { - case FieldTypeBool: - return "bool" - case FieldTypeInt8: - return "int8" - case FieldTypeInt16: - return "int16" - case FieldTypeInt32: - return "int32" - case FieldTypeInt64: - return "int64" - case FieldTypeFloat: - return "float32" - case FieldTypeDouble: - return "float64" - case FieldTypeString: - return "string" - case FieldTypeVarChar: - return "string" - case FieldTypeArray: - return "Array" - case FieldTypeJSON: - return "JSON" - case FieldTypeBinaryVector: - return "[]byte" - case FieldTypeFloatVector: - return "[]float32" - case FieldTypeFloat16Vector: - return "[]byte" - case FieldTypeBFloat16Vector: - return "[]byte" - default: - return "undefined" - } -} - -// PbFieldType represents FieldType corresponding schema pb type -func (t FieldType) PbFieldType() (string, string) { - switch t { - case FieldTypeBool: - return "Bool", "bool" - case FieldTypeInt8: - fallthrough - case FieldTypeInt16: - fallthrough - case FieldTypeInt32: - return "Int", "int32" - case FieldTypeInt64: - return "Long", "int64" - case FieldTypeFloat: - return "Float", "float32" - case FieldTypeDouble: - return "Double", "float64" - case FieldTypeString: - return "String", "string" - case FieldTypeVarChar: - return "VarChar", "string" - case FieldTypeJSON: - return "JSON", "JSON" - case FieldTypeBinaryVector: - return "[]byte", "" - case FieldTypeFloatVector: - return "[]float32", "" - case FieldTypeFloat16Vector: - return "[]byte", "" - case FieldTypeBFloat16Vector: - return "[]byte", "" - default: - return "undefined", "" - } -} - -// Match schema definition -const ( - // FieldTypeNone zero value place holder - FieldTypeNone FieldType = 0 // zero value place holder - // FieldTypeBool field type boolean - FieldTypeBool FieldType = 1 - // FieldTypeInt8 field type int8 - FieldTypeInt8 FieldType = 2 - // FieldTypeInt16 field type int16 - FieldTypeInt16 FieldType = 3 - // FieldTypeInt32 field type int32 - FieldTypeInt32 FieldType = 4 - // FieldTypeInt64 field type int64 - FieldTypeInt64 FieldType = 5 - // FieldTypeFloat field type float - FieldTypeFloat FieldType = 10 - // FieldTypeDouble field type double - FieldTypeDouble FieldType = 11 - // FieldTypeString field type string - FieldTypeString FieldType = 20 - // FieldTypeVarChar field type varchar - FieldTypeVarChar FieldType = 21 // variable-length strings with a specified maximum length - // FieldTypeArray field type Array - FieldTypeArray FieldType = 22 - // FieldTypeJSON field type JSON - FieldTypeJSON FieldType = 23 - // FieldTypeBinaryVector field type binary vector - FieldTypeBinaryVector FieldType = 100 - // FieldTypeFloatVector field type float vector - FieldTypeFloatVector FieldType = 101 - // FieldTypeBinaryVector field type float16 vector - FieldTypeFloat16Vector FieldType = 102 - // FieldTypeBinaryVector field type bf16 vector - FieldTypeBFloat16Vector FieldType = 103 - // FieldTypeBinaryVector field type sparse vector - FieldTypeSparseVector FieldType = 104 -) diff --git a/client/entity/function.go b/client/entity/function.go new file mode 100644 index 0000000000..5e8c85dcde --- /dev/null +++ b/client/entity/function.go @@ -0,0 +1,109 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package entity + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +type FunctionType = schemapb.FunctionType + +// provide package alias +const ( + FunctionTypeUnknown = schemapb.FunctionType_Unknown + FunctionTypeBM25 = schemapb.FunctionType_BM25 + FunctionTypeTextEmbedding = schemapb.FunctionType_TextEmbedding +) + +type Function struct { + Name string + Description string + Type FunctionType + + InputFieldNames []string + OutputFieldNames []string + Params map[string]string + + // ids shall be private + id int64 + inputFieldIDs []int64 + outputFieldIDs []int64 +} + +func NewFunction() *Function { + return &Function{ + Params: make(map[string]string), + } +} + +func (f *Function) WithName(name string) *Function { + f.Name = name + return f +} + +func (f *Function) WithInputFields(inputFields ...string) *Function { + f.InputFieldNames = inputFields + return f +} + +func (f *Function) WithOutputFields(outputFields ...string) *Function { + f.OutputFieldNames = outputFields + return f +} + +func (f *Function) WithType(funcType FunctionType) *Function { + f.Type = funcType + return f +} + +func (f *Function) WithParam(key string, value any) *Function { + f.Params[key] = fmt.Sprintf("%v", value) + return f +} + +// ProtoMessage returns corresponding schemapb.FunctionSchema +func (f *Function) ProtoMessage() *schemapb.FunctionSchema { + r := &schemapb.FunctionSchema{ + Name: f.Name, + Description: f.Description, + Type: f.Type, + InputFieldNames: f.InputFieldNames, + OutputFieldNames: f.OutputFieldNames, + Params: MapKvPairs(f.Params), + } + + return r +} + +// ReadProto parses proto Collection Schema +func (f *Function) ReadProto(p *schemapb.FunctionSchema) *Function { + f.Name = p.GetName() + f.Description = p.GetDescription() + f.Type = p.GetType() + + f.InputFieldNames = p.GetInputFieldNames() + f.OutputFieldNames = p.GetOutputFieldNames() + f.Params = KvPairsMap(p.GetParams()) + + f.id = p.GetId() + f.inputFieldIDs = p.GetInputFieldIds() + f.outputFieldIDs = p.GetOutputFieldIds() + + return f +} diff --git a/client/entity/function_test.go b/client/entity/function_test.go new file mode 100644 index 0000000000..6e73596bcd --- /dev/null +++ b/client/entity/function_test.go @@ -0,0 +1,48 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package entity + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFunctionSchema(t *testing.T) { + functions := []*Function{ + NewFunction().WithName("text_bm25_emb").WithType(FunctionTypeBM25).WithInputFields("a", "b").WithOutputFields("c").WithParam("key", "value"), + NewFunction().WithName("other_emb").WithType(FunctionTypeTextEmbedding).WithInputFields("c").WithOutputFields("b", "a"), + } + + for _, function := range functions { + funcSchema := function.ProtoMessage() + assert.Equal(t, function.Name, funcSchema.Name) + assert.Equal(t, function.Type, funcSchema.Type) + assert.Equal(t, function.InputFieldNames, funcSchema.InputFieldNames) + assert.Equal(t, function.OutputFieldNames, funcSchema.OutputFieldNames) + assert.Equal(t, function.Params, KvPairsMap(funcSchema.GetParams())) + + nf := NewFunction() + nf.ReadProto(funcSchema) + + assert.Equal(t, function.Name, nf.Name) + assert.Equal(t, function.Type, nf.Type) + assert.Equal(t, function.InputFieldNames, nf.InputFieldNames) + assert.Equal(t, function.OutputFieldNames, nf.OutputFieldNames) + assert.Equal(t, function.Params, nf.Params) + } +} diff --git a/client/entity/schema.go b/client/entity/schema.go index ab8878d7bb..501ec7b9b0 100644 --- a/client/entity/schema.go +++ b/client/entity/schema.go @@ -17,9 +17,7 @@ package entity import ( - "strconv" - - "github.com/cockroachdb/errors" + "github.com/samber/lo" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" @@ -62,6 +60,7 @@ type Schema struct { AutoID bool Fields []*Field EnableDynamicField bool + Functions []*Function pkField *Field } @@ -102,6 +101,11 @@ func (s *Schema) WithField(f *Field) *Schema { return s } +func (s *Schema) WithFunction(f *Function) *Schema { + s.Functions = append(s.Functions, f) + return s +} + // ProtoMessage returns corresponding server.CollectionSchema func (s *Schema) ProtoMessage() *schemapb.CollectionSchema { r := &schemapb.CollectionSchema{ @@ -110,10 +114,14 @@ func (s *Schema) ProtoMessage() *schemapb.CollectionSchema { AutoID: s.AutoID, EnableDynamicField: s.EnableDynamicField, } - r.Fields = make([]*schemapb.FieldSchema, 0, len(s.Fields)) - for _, field := range s.Fields { - r.Fields = append(r.Fields, field.ProtoMessage()) - } + r.Fields = lo.Map(s.Fields, func(field *Field, _ int) *schemapb.FieldSchema { + return field.ProtoMessage() + }) + + r.Functions = lo.Map(s.Functions, func(function *Function, _ int) *schemapb.FunctionSchema { + return function.ProtoMessage() + }) + return r } @@ -121,6 +129,8 @@ func (s *Schema) ProtoMessage() *schemapb.CollectionSchema { func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema { s.Description = p.GetDescription() s.CollectionName = p.GetName() + s.EnableDynamicField = p.GetEnableDynamicField() + // fields s.Fields = make([]*Field, 0, len(p.GetFields())) for _, fp := range p.GetFields() { field := NewField().ReadProto(fp) @@ -132,7 +142,11 @@ func (s *Schema) ReadProto(p *schemapb.CollectionSchema) *Schema { } s.Fields = append(s.Fields, field) } - s.EnableDynamicField = p.GetEnableDynamicField() + // functions + s.Functions = lo.Map(p.GetFunctions(), func(fn *schemapb.FunctionSchema, _ int) *Function { + return NewFunction().ReadProto(fn) + }) + return s } @@ -149,210 +163,6 @@ func (s *Schema) PKField() *Field { return s.pkField } -// Field represent field schema in milvus -type Field struct { - ID int64 // field id, generated when collection is created, input value is ignored - Name string // field name - PrimaryKey bool // is primary key - AutoID bool // is auto id - Description string - DataType FieldType - TypeParams map[string]string - IndexParams map[string]string - IsDynamic bool - IsPartitionKey bool - IsClusteringKey bool - ElementType FieldType -} - -// ProtoMessage generates corresponding FieldSchema -func (f *Field) ProtoMessage() *schemapb.FieldSchema { - return &schemapb.FieldSchema{ - FieldID: f.ID, - Name: f.Name, - Description: f.Description, - IsPrimaryKey: f.PrimaryKey, - AutoID: f.AutoID, - DataType: schemapb.DataType(f.DataType), - TypeParams: MapKvPairs(f.TypeParams), - IndexParams: MapKvPairs(f.IndexParams), - IsDynamic: f.IsDynamic, - IsPartitionKey: f.IsPartitionKey, - IsClusteringKey: f.IsClusteringKey, - ElementType: schemapb.DataType(f.ElementType), - } -} - -// NewField creates a new Field with map initialized. -func NewField() *Field { - return &Field{ - TypeParams: make(map[string]string), - IndexParams: make(map[string]string), - } -} - -func (f *Field) WithName(name string) *Field { - f.Name = name - return f -} - -func (f *Field) WithDescription(desc string) *Field { - f.Description = desc - return f -} - -func (f *Field) WithDataType(dataType FieldType) *Field { - f.DataType = dataType - return f -} - -func (f *Field) WithIsPrimaryKey(isPrimaryKey bool) *Field { - f.PrimaryKey = isPrimaryKey - return f -} - -func (f *Field) WithIsAutoID(isAutoID bool) *Field { - f.AutoID = isAutoID - return f -} - -func (f *Field) WithIsDynamic(isDynamic bool) *Field { - f.IsDynamic = isDynamic - return f -} - -func (f *Field) WithIsPartitionKey(isPartitionKey bool) *Field { - f.IsPartitionKey = isPartitionKey - return f -} - -func (f *Field) WithIsClusteringKey(isClusteringKey bool) *Field { - f.IsClusteringKey = isClusteringKey - return f -} - -/* -func (f *Field) WithDefaultValueBool(defaultValue bool) *Field { - f.DefaultValue = &schemapb.ValueField{ - Data: &schemapb.ValueField_BoolData{ - BoolData: defaultValue, - }, - } - return f -} - -func (f *Field) WithDefaultValueInt(defaultValue int32) *Field { - f.DefaultValue = &schemapb.ValueField{ - Data: &schemapb.ValueField_IntData{ - IntData: defaultValue, - }, - } - return f -} - -func (f *Field) WithDefaultValueLong(defaultValue int64) *Field { - f.DefaultValue = &schemapb.ValueField{ - Data: &schemapb.ValueField_LongData{ - LongData: defaultValue, - }, - } - return f -} - -func (f *Field) WithDefaultValueFloat(defaultValue float32) *Field { - f.DefaultValue = &schemapb.ValueField{ - Data: &schemapb.ValueField_FloatData{ - FloatData: defaultValue, - }, - } - return f -} - -func (f *Field) WithDefaultValueDouble(defaultValue float64) *Field { - f.DefaultValue = &schemapb.ValueField{ - Data: &schemapb.ValueField_DoubleData{ - DoubleData: defaultValue, - }, - } - return f -} - -func (f *Field) WithDefaultValueString(defaultValue string) *Field { - f.DefaultValue = &schemapb.ValueField{ - Data: &schemapb.ValueField_StringData{ - StringData: defaultValue, - }, - } - return f -}*/ - -func (f *Field) WithTypeParams(key string, value string) *Field { - if f.TypeParams == nil { - f.TypeParams = make(map[string]string) - } - f.TypeParams[key] = value - return f -} - -func (f *Field) WithDim(dim int64) *Field { - if f.TypeParams == nil { - f.TypeParams = make(map[string]string) - } - f.TypeParams[TypeParamDim] = strconv.FormatInt(dim, 10) - return f -} - -func (f *Field) GetDim() (int64, error) { - dimStr, has := f.TypeParams[TypeParamDim] - if !has { - return -1, errors.New("field with no dim") - } - dim, err := strconv.ParseInt(dimStr, 10, 64) - if err != nil { - return -1, errors.Newf("field with bad format dim: %s", err.Error()) - } - return dim, nil -} - -func (f *Field) WithMaxLength(maxLen int64) *Field { - if f.TypeParams == nil { - f.TypeParams = make(map[string]string) - } - f.TypeParams[TypeParamMaxLength] = strconv.FormatInt(maxLen, 10) - return f -} - -func (f *Field) WithElementType(eleType FieldType) *Field { - f.ElementType = eleType - return f -} - -func (f *Field) WithMaxCapacity(maxCap int64) *Field { - if f.TypeParams == nil { - f.TypeParams = make(map[string]string) - } - f.TypeParams[TypeParamMaxCapacity] = strconv.FormatInt(maxCap, 10) - return f -} - -// ReadProto parses FieldSchema -func (f *Field) ReadProto(p *schemapb.FieldSchema) *Field { - f.ID = p.GetFieldID() - f.Name = p.GetName() - f.PrimaryKey = p.GetIsPrimaryKey() - f.AutoID = p.GetAutoID() - f.Description = p.GetDescription() - f.DataType = FieldType(p.GetDataType()) - f.TypeParams = KvPairsMap(p.GetTypeParams()) - f.IndexParams = KvPairsMap(p.GetIndexParams()) - f.IsDynamic = p.GetIsDynamic() - f.IsPartitionKey = p.GetIsPartitionKey() - f.IsClusteringKey = p.GetIsClusteringKey() - f.ElementType = FieldType(p.GetElementType()) - - return f -} - // MapKvPairs converts map into commonpb.KeyValuePair slice func MapKvPairs(m map[string]string) []*commonpb.KeyValuePair { pairs := make([]*commonpb.KeyValuePair, 0, len(m)) diff --git a/client/entity/schema_test.go b/client/entity/schema_test.go index ed81c39e50..fb02476d98 100644 --- a/client/entity/schema_test.go +++ b/client/entity/schema_test.go @@ -37,56 +37,6 @@ func TestCL_CommonCL(t *testing.T) { } } -func TestFieldSchema(t *testing.T) { - fields := []*Field{ - NewField().WithName("int_field").WithDataType(FieldTypeInt64).WithIsAutoID(true).WithIsPrimaryKey(true).WithDescription("int_field desc"), - NewField().WithName("string_field").WithDataType(FieldTypeString).WithIsAutoID(false).WithIsPrimaryKey(true).WithIsDynamic(false).WithTypeParams("max_len", "32").WithDescription("string_field desc"), - NewField().WithName("partition_key").WithDataType(FieldTypeInt32).WithIsPartitionKey(true), - NewField().WithName("array_field").WithDataType(FieldTypeArray).WithElementType(FieldTypeBool).WithMaxCapacity(128), - NewField().WithName("clustering_key").WithDataType(FieldTypeInt32).WithIsClusteringKey(true), - /* - NewField().WithName("default_value_bool").WithDataType(FieldTypeBool).WithDefaultValueBool(true), - NewField().WithName("default_value_int").WithDataType(FieldTypeInt32).WithDefaultValueInt(1), - NewField().WithName("default_value_long").WithDataType(FieldTypeInt64).WithDefaultValueLong(1), - NewField().WithName("default_value_float").WithDataType(FieldTypeFloat).WithDefaultValueFloat(1), - NewField().WithName("default_value_double").WithDataType(FieldTypeDouble).WithDefaultValueDouble(1), - NewField().WithName("default_value_string").WithDataType(FieldTypeString).WithDefaultValueString("a"),*/ - } - - for _, field := range fields { - fieldSchema := field.ProtoMessage() - assert.Equal(t, field.ID, fieldSchema.GetFieldID()) - assert.Equal(t, field.Name, fieldSchema.GetName()) - assert.EqualValues(t, field.DataType, fieldSchema.GetDataType()) - assert.Equal(t, field.AutoID, fieldSchema.GetAutoID()) - assert.Equal(t, field.PrimaryKey, fieldSchema.GetIsPrimaryKey()) - assert.Equal(t, field.IsPartitionKey, fieldSchema.GetIsPartitionKey()) - assert.Equal(t, field.IsClusteringKey, fieldSchema.GetIsClusteringKey()) - assert.Equal(t, field.IsDynamic, fieldSchema.GetIsDynamic()) - assert.Equal(t, field.Description, fieldSchema.GetDescription()) - assert.Equal(t, field.TypeParams, KvPairsMap(fieldSchema.GetTypeParams())) - assert.EqualValues(t, field.ElementType, fieldSchema.GetElementType()) - // marshal & unmarshal, still equals - nf := &Field{} - nf = nf.ReadProto(fieldSchema) - assert.Equal(t, field.ID, nf.ID) - assert.Equal(t, field.Name, nf.Name) - assert.EqualValues(t, field.DataType, nf.DataType) - assert.Equal(t, field.AutoID, nf.AutoID) - assert.Equal(t, field.PrimaryKey, nf.PrimaryKey) - assert.Equal(t, field.Description, nf.Description) - assert.Equal(t, field.IsDynamic, nf.IsDynamic) - assert.Equal(t, field.IsPartitionKey, nf.IsPartitionKey) - assert.Equal(t, field.IsClusteringKey, nf.IsClusteringKey) - assert.EqualValues(t, field.TypeParams, nf.TypeParams) - assert.EqualValues(t, field.ElementType, nf.ElementType) - } - - assert.NotPanics(t, func() { - (&Field{}).WithTypeParams("a", "b") - }) -} - type SchemaSuite struct { suite.Suite } @@ -101,7 +51,8 @@ func (s *SchemaSuite) TestBasic() { "test_collection", NewSchema().WithName("test_collection_1").WithDescription("test_collection_1 desc").WithAutoID(false). WithField(NewField().WithName("ID").WithDataType(FieldTypeInt64).WithIsPrimaryKey(true)). - WithField(NewField().WithName("vector").WithDataType(FieldTypeFloatVector).WithDim(128)), + WithField(NewField().WithName("vector").WithDataType(FieldTypeFloatVector).WithDim(128)). + WithFunction(NewFunction()), "ID", }, { @@ -122,6 +73,7 @@ func (s *SchemaSuite) TestBasic() { s.Equal(sch.Description, p.GetDescription()) s.Equal(sch.EnableDynamicField, p.GetEnableDynamicField()) s.Equal(len(sch.Fields), len(p.GetFields())) + s.Equal(len(sch.Functions), len(p.GetFunctions())) nsch := &Schema{} nsch = nsch.ReadProto(p) @@ -130,6 +82,7 @@ func (s *SchemaSuite) TestBasic() { s.Equal(sch.Description, nsch.Description) s.Equal(sch.EnableDynamicField, nsch.EnableDynamicField) s.Equal(len(sch.Fields), len(nsch.Fields)) + s.Equal(len(sch.Functions), len(nsch.Functions)) s.Equal(c.pkName, sch.PKFieldName()) s.Equal(c.pkName, nsch.PKFieldName()) }) diff --git a/client/entity/vectors.go b/client/entity/vectors.go index 82f1fe5979..14f14bd5fd 100644 --- a/client/entity/vectors.go +++ b/client/entity/vectors.go @@ -104,3 +104,19 @@ func (bv BinaryVector) Serialize() []byte { func (bv BinaryVector) FieldType() FieldType { return FieldTypeBinaryVector } + +type Text string + +// Dim returns vector dimension. +func (t Text) Dim() int { + return 0 +} + +// entity.FieldType returns coresponding field type. +func (t Text) FieldType() FieldType { + return FieldTypeVarChar +} + +func (t Text) Serialize() []byte { + return []byte(t) +} diff --git a/client/milvusclient/read.go b/client/milvusclient/read.go index 0730f93a69..bed2859a18 100644 --- a/client/milvusclient/read.go +++ b/client/milvusclient/read.go @@ -30,7 +30,10 @@ import ( ) func (c *Client) Search(ctx context.Context, option SearchOption, callOptions ...grpc.CallOption) ([]ResultSet, error) { - req := option.Request() + req, err := option.Request() + if err != nil { + return nil, err + } collection, err := c.getCollection(ctx, req.GetCollectionName()) if err != nil { return nil, err diff --git a/client/milvusclient/read_option_test.go b/client/milvusclient/read_option_test.go new file mode 100644 index 0000000000..0e50db0580 --- /dev/null +++ b/client/milvusclient/read_option_test.go @@ -0,0 +1,139 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package milvusclient + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/client/v2/entity" +) + +type SearchOptionSuite struct { + suite.Suite +} + +type nonSupportData struct{} + +func (d nonSupportData) Serialize() []byte { + return []byte{} +} + +func (d nonSupportData) Dim() int { + return 0 +} + +func (d nonSupportData) FieldType() entity.FieldType { + return entity.FieldType(0) +} + +func (s *SearchOptionSuite) TestBasic() { + collName := "search_opt_basic" + + topK := rand.Intn(100) + 1 + opt := NewSearchOption(collName, topK, []entity.Vector{entity.FloatVector([]float32{0.1, 0.2})}) + + opt = opt.WithANNSField("test_field").WithOutputFields("ID", "Value").WithConsistencyLevel(entity.ClStrong).WithFilter("ID > 1000") + + req, err := opt.Request() + s.Require().NoError(err) + + s.Equal(collName, req.GetCollectionName()) + s.Equal("ID > 1000", req.GetDsl()) + s.ElementsMatch([]string{"ID", "Value"}, req.GetOutputFields()) + searchParams := entity.KvPairsMap(req.GetSearchParams()) + annField, ok := searchParams[spAnnsField] + s.Require().True(ok) + s.Equal("test_field", annField) + + opt = NewSearchOption(collName, topK, []entity.Vector{nonSupportData{}}) + _, err = opt.Request() + s.Error(err) +} + +func (s *SearchOptionSuite) TestPlaceHolder() { + type testCase struct { + tag string + input []entity.Vector + expectError bool + expectType commonpb.PlaceholderType + } + + sparse, err := entity.NewSliceSparseEmbedding([]uint32{0, 10, 12}, []float32{0.1, 0.2, 0.3}) + s.Require().NoError(err) + + cases := []*testCase{ + { + tag: "empty_input", + input: nil, + expectType: commonpb.PlaceholderType_None, + }, + { + tag: "float_vector", + input: []entity.Vector{entity.FloatVector([]float32{0.1, 0.2, 0.3})}, + expectType: commonpb.PlaceholderType_FloatVector, + }, + { + tag: "sparse_vector", + input: []entity.Vector{sparse}, + expectType: commonpb.PlaceholderType_SparseFloatVector, + }, + { + tag: "fp16_vector", + input: []entity.Vector{entity.Float16Vector([]byte{})}, + expectType: commonpb.PlaceholderType_Float16Vector, + }, + { + tag: "bf16_vector", + input: []entity.Vector{entity.BFloat16Vector([]byte{})}, + expectType: commonpb.PlaceholderType_BFloat16Vector, + }, + { + tag: "binary_vector", + input: []entity.Vector{entity.BinaryVector([]byte{})}, + expectType: commonpb.PlaceholderType_BinaryVector, + }, + { + tag: "text", + input: []entity.Vector{entity.Text("abc")}, + expectType: commonpb.PlaceholderType_VarChar, + }, + { + tag: "non_supported", + input: []entity.Vector{nonSupportData{}}, + expectError: true, + }, + } + for _, tc := range cases { + s.Run(tc.tag, func() { + phv, err := vector2Placeholder(tc.input) + if tc.expectError { + s.Error(err) + } else { + s.NoError(err) + s.Equal(tc.expectType, phv.GetType()) + } + }) + } +} + +func TestSearchOption(t *testing.T) { + suite.Run(t, new(SearchOptionSuite)) +} diff --git a/client/milvusclient/read_options.go b/client/milvusclient/read_options.go index 8685c45de3..9f03d9a048 100644 --- a/client/milvusclient/read_options.go +++ b/client/milvusclient/read_options.go @@ -20,6 +20,7 @@ import ( "encoding/json" "strconv" + "github.com/cockroachdb/errors" "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" @@ -40,7 +41,7 @@ const ( ) type SearchOption interface { - Request() *milvuspb.SearchRequest + Request() (*milvuspb.SearchRequest, error) } var _ SearchOption = (*searchOption)(nil) @@ -70,12 +71,12 @@ type annRequest struct { groupByField string } -func (opt *searchOption) Request() *milvuspb.SearchRequest { +func (opt *searchOption) Request() (*milvuspb.SearchRequest, error) { // TODO check whether search is hybrid after logic merged return opt.prepareSearchRequest(opt.request) } -func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb.SearchRequest { +func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) (*milvuspb.SearchRequest, error) { request := &milvuspb.SearchRequest{ CollectionName: opt.collectionName, PartitionNames: opt.partitionNames, @@ -104,11 +105,15 @@ func (opt *searchOption) prepareSearchRequest(annRequest *annRequest) *milvuspb. } request.SearchParams = entity.MapKvPairs(params) + var err error // placeholder group - request.PlaceholderGroup = vector2PlaceholderGroupBytes(annRequest.vectors) + request.PlaceholderGroup, err = vector2PlaceholderGroupBytes(annRequest.vectors) + if err != nil { + return nil, err + } } - return request + return request, nil } func (opt *searchOption) WithFilter(expr string) *searchOption { @@ -159,25 +164,29 @@ func NewSearchOption(collectionName string, limit int, vectors []entity.Vector) } } -func vector2PlaceholderGroupBytes(vectors []entity.Vector) []byte { +func vector2PlaceholderGroupBytes(vectors []entity.Vector) ([]byte, error) { + phv, err := vector2Placeholder(vectors) + if err != nil { + return nil, err + } phg := &commonpb.PlaceholderGroup{ Placeholders: []*commonpb.PlaceholderValue{ - vector2Placeholder(vectors), + phv, }, } - bs, _ := proto.Marshal(phg) - return bs + bs, err := proto.Marshal(phg) + return bs, err } -func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue { +func vector2Placeholder(vectors []entity.Vector) (*commonpb.PlaceholderValue, error) { var placeHolderType commonpb.PlaceholderType ph := &commonpb.PlaceholderValue{ Tag: "$0", Values: make([][]byte, 0, len(vectors)), } if len(vectors) == 0 { - return ph + return ph, nil } switch vectors[0].(type) { case entity.FloatVector: @@ -190,12 +199,16 @@ func vector2Placeholder(vectors []entity.Vector) *commonpb.PlaceholderValue { placeHolderType = commonpb.PlaceholderType_Float16Vector case entity.SparseEmbedding: placeHolderType = commonpb.PlaceholderType_SparseFloatVector + case entity.Text: + placeHolderType = commonpb.PlaceholderType_VarChar + default: + return nil, errors.Newf("unsupported search data type: %T", vectors[0]) } ph.Type = placeHolderType for _, vector := range vectors { ph.Values = append(ph.Values, vector.Serialize()) } - return ph + return ph, nil } type QueryOption interface { diff --git a/client/milvusclient/read_test.go b/client/milvusclient/read_test.go index 381f3e0431..b9840a9249 100644 --- a/client/milvusclient/read_test.go +++ b/client/milvusclient/read_test.go @@ -118,11 +118,14 @@ func (s *ReadSuite) TestSearch() { collectionName := fmt.Sprintf("coll_%s", s.randString(6)) s.setupCache(collectionName, s.schemaDyn) + _, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{nonSupportData{}})) + s.Error(err) + s.mock.EXPECT().Search(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, sr *milvuspb.SearchRequest) (*milvuspb.SearchResults, error) { return nil, merr.WrapErrServiceInternal("mocked") }).Once() - _, err := s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{ + _, err = s.client.Search(ctx, NewSearchOption(collectionName, 10, []entity.Vector{ entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 { return rand.Float32() })), diff --git a/client/milvusclient/write_options.go b/client/milvusclient/write_options.go index 826d36031c..4d243fe040 100644 --- a/client/milvusclient/write_options.go +++ b/client/milvusclient/write_options.go @@ -109,14 +109,15 @@ func (opt *columnBasedDataOption) processInsertColumns(colSchema *entity.Schema, } } - // check all fixed field pass value - for _, field := range colSchema.Fields { - _, has := mNameColumn[field.Name] - if !has && - !field.AutoID && !field.IsDynamic { - return nil, 0, fmt.Errorf("field %s not passed", field.Name) - } - } + // missing field shall be checked in server side + // // check all fixed field pass value + // for _, field := range colSchema.Fields { + // _, has := mNameColumn[field.Name] + // if !has && + // !field.AutoID && !field.IsDynamic { + // return nil, 0, fmt.Errorf("field %s not passed", field.Name) + // } + // } fieldsData := make([]*schemapb.FieldData, 0, len(mNameColumn)+1) for _, fixedColumn := range mNameColumn { diff --git a/client/milvusclient/write_test.go b/client/milvusclient/write_test.go index 133dd3fe2a..601ca242d2 100644 --- a/client/milvusclient/write_test.go +++ b/client/milvusclient/write_test.go @@ -129,10 +129,6 @@ func (s *WriteSuite) TestInsert() { } cases := []badCase{ - { - tag: "missing_column", - input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}), - }, { tag: "row_count_not_match", input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}). @@ -261,10 +257,6 @@ func (s *WriteSuite) TestUpsert() { } cases := []badCase{ - { - tag: "missing_column", - input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}), - }, { tag: "row_count_not_match", input: NewColumnBasedInsertOption(collName).WithInt64Column("id", []int64{1}). diff --git a/pkg/common/common.go b/pkg/common/common.go index 768577ccf9..86219ed5ae 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -140,6 +140,12 @@ const ( ConsistencyLevel = "consistency_level" ) +// Doc-in-doc-out +const ( + EnableAnalyzerKey = `enable_analyzer` + AnalyzerParamKey = `analyzer_params` +) + // Collection properties key const (