test: add test cases for int8 vector (#41957)
Signed-off-by: binbin lv <binbin.lv@zilliz.com>
This commit is contained in:
parent 1735f557ca
commit eea6b50fbb
@@ -368,6 +368,8 @@ class TestcaseBase(Base):
            # Unlike dense vectors, sparse vectors cannot create flat index.
            if DataType.SPARSE_FLOAT_VECTOR.name in vector_name:
                collection_w.create_index(vector_name, ct.default_sparse_inverted_index)
            elif vector_data_type == DataType.INT8_VECTOR:
                collection_w.create_index(vector_name, ct.int8_vector_index)
            else:
                collection_w.create_index(vector_name, ct.default_flat_index)
@@ -1100,4 +1100,26 @@ class TestMilvusClientV2Base(Base):
                                       source_group=source_group, target_group=target_group,
                                       collection_name=collection_name, num_replicas=num_replicas,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def create_field_schema(self, client, name, data_type, desc='', timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.create_field_schema, name, data_type, desc], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       **kwargs).run()
        return res, check_result

    @trace()
    def add_collection_field(self, client, collection_name, field_schema, timeout=None, check_task=None, check_items=None, **kwargs):
        timeout = TIMEOUT if timeout is None else timeout
        kwargs.update({"timeout": timeout})

        func_name = sys._getframe().f_code.co_name
        res, check = api_request([client.add_collection_field, collection_name, field_schema], **kwargs)
        check_result = ResponseChecker(res, func_name, check_task, check_items, check,
                                       **kwargs).run()
        return res, check_result
@@ -7,9 +7,10 @@ from common import common_type as ct
from common import common_func as cf
from common.common_type import CheckTasks, Connect_Object_Name
# from common.code_mapping import ErrorCode, ErrorMessage
from pymilvus import Collection, Partition, ResourceGroupInfo
from pymilvus import Collection, Partition, ResourceGroupInfo, DataType
import check.param_check as pc

import numpy as np
from ml_dtypes import bfloat16

class Error:
    def __init__(self, error):
@@ -259,8 +260,27 @@ class ResponseChecker:
        if check_items.get("id_name", "id"):
            assert res["fields"][0]["name"] == check_items.get("id_name", "id")
        if check_items.get("vector_name", "vector"):
            assert res["fields"][1]["name"] == check_items.get("vector_name", "vector")
            vector_name_list = []
            vector_name_list_expected = check_items.get("vector_name", "vector")
            for field in res["fields"]:
                if field["type"] in [101, 102, 103, 105]:
                    vector_name_list.append(field["name"])
            if isinstance(vector_name_list_expected, str):
                assert vector_name_list[0] == check_items.get("vector_name", "vector")
            else:
                assert vector_name_list == vector_name_list_expected
        if check_items.get("dim", None) is not None:
            dim_list = []
            # dim accepts an int for a single vector field or a list for multiple vector fields;
            # the order should match the order in which the fields were added to the schema
            dim_list_expected = check_items.get("dim")
            for field in res["fields"]:
                if field["type"] in [101, 102, 103, 105]:
                    dim_list.append(field["params"]["dim"])
            if isinstance(dim_list_expected, int):
                assert dim_list[0] == dim_list_expected
            else:
                assert dim_list == dim_list_expected
            assert res["fields"][1]["params"]["dim"] == check_items.get("dim")
        if check_items.get("nullable_fields", None) is not None:
            nullable_fields = check_items.get("nullable_fields")
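Reviewer note on the checks above: to the best of my knowledge, the numeric codes 101, 102, 103 and 105 are the pymilvus DataType values for FLOAT_VECTOR, FLOAT16_VECTOR, BFLOAT16_VECTOR and INT8_VECTOR, i.e. the vector types this checker reads a dim param from (SPARSE_FLOAT_VECTOR, 104, carries no dim and is skipped; BINARY_VECTOR, 100, is not covered here). A small hedged sanity check, to be verified against the installed pymilvus version:

    # Hedged sanity check of the assumed DataType codes used in the filter above.
    from pymilvus import DataType

    dense_vector_types_checked = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR,
                                  DataType.BFLOAT16_VECTOR, DataType.INT8_VECTOR]
    assert [int(t) for t in dense_vector_types_checked] == [101, 102, 103, 105]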
@@ -272,7 +292,7 @@ class ResponseChecker:
                    assert field["nullable"] is True
        assert res["fields"][0]["is_primary"] is True
        assert res["fields"][0]["field_id"] == 100 and (res["fields"][0]["type"] == 5 or 21)
        assert res["fields"][1]["field_id"] == 101 and res["fields"][1]["type"] == 101
        assert res["fields"][1]["field_id"] == 101 and (res["fields"][1]["type"] == 101 or 105)

        return True

@@ -540,6 +560,22 @@ class ResponseChecker:
        exp_res = check_items.get("exp_res", None)
        with_vec = check_items.get("with_vec", False)
        pk_name = check_items.get("pk_name", ct.default_primary_field_name)
        vector_type = check_items.get("vector_type", "FLOAT_VECTOR")
        if vector_type == DataType.FLOAT16_VECTOR:
            for single_exp_res in exp_res:
                single_exp_res['vector'] = single_exp_res['vector'].tolist()
            for single_query_result in query_res:
                single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.float16).tolist()
        if vector_type == DataType.BFLOAT16_VECTOR:
            for single_exp_res in exp_res:
                single_exp_res['vector'] = single_exp_res['vector'].tolist()
            for single_query_result in query_res:
                single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=bfloat16).tolist()
        if vector_type == DataType.INT8_VECTOR:
            for single_exp_res in exp_res:
                single_exp_res['vector'] = single_exp_res['vector'].tolist()
            for single_query_result in query_res:
                single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.int8).tolist()
        if exp_res is not None:
            if isinstance(query_res, list):
                assert pc.equal_entities_list(exp=exp_res, actual=query_res, primary_field=pk_name,
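Reviewer note: the decoding added above relies on query results returning FLOAT16/BFLOAT16/INT8 vectors as a packed bytes payload wrapped in a one-element list (hence the ['vector'][0] indexing). A minimal round-trip sketch of that decode for the int8 case, under that assumption:

    import numpy as np

    # Assumed result shape: packed little-endian bytes inside a one-element list.
    original = np.array([1, -2, 127, -128], dtype=np.int8)
    payload = [original.tobytes()]
    decoded = np.frombuffer(payload[0], dtype=np.int8).tolist()
    assert decoded == original.tolist()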
@@ -696,17 +696,18 @@ def gen_float_vec_field(name=ct.default_float_vec_field_name, is_primary=False,

    if vector_data_type != DataType.SPARSE_FLOAT_VECTOR:
        float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=vector_data_type,
                                                                       description=description, dim=dim,
                                                                       is_primary=is_primary, **kwargs)
                                                                       description=description, dim=dim,
                                                                       is_primary=is_primary, **kwargs)
    else:
        # no dim for sparse vector
        float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.SPARSE_FLOAT_VECTOR,
                                                                       description=description,
                                                                       is_primary=is_primary, **kwargs)
                                                                       description=description,
                                                                       is_primary=is_primary, **kwargs)

    return float_vec_field


def gen_binary_vec_field(name=ct.default_binary_vec_field_name, is_primary=False, dim=ct.default_dim,
                         description=ct.default_desc, **kwargs):
    binary_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.BINARY_VECTOR,
@@ -792,7 +793,8 @@ def gen_default_collection_schema(description=ct.default_desc, primary_field=ct.

    if len(multiple_dim_array) != 0:
        for other_dim in multiple_dim_array:
            fields.append(gen_float_vec_field(gen_unique_str("multiple_vector"), dim=other_dim,
            name_prefix = "multiple_vector"
            fields.append(gen_float_vec_field(gen_unique_str(name_prefix), dim=other_dim,
                                              vector_data_type=vector_data_type))

    schema, _ = ApiCollectionSchemaWrapper().init_collection_schema(fields=fields, description=description,
@@ -1120,6 +1122,32 @@ def gen_schema_multi_string_fields(string_fields):
    return schema


def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
    vectors = []
    if vector_data_type == DataType.FLOAT_VECTOR:
        vectors = [[random.random() for _ in range(dim)] for _ in range(nb)]
    elif vector_data_type == DataType.FLOAT16_VECTOR:
        vectors = gen_fp16_vectors(nb, dim)[1]
    elif vector_data_type == DataType.BFLOAT16_VECTOR:
        vectors = gen_bf16_vectors(nb, dim)[1]
    elif vector_data_type == DataType.SPARSE_FLOAT_VECTOR:
        vectors = gen_sparse_vectors(nb, dim)
    elif vector_data_type == ct.text_sparse_vector:
        vectors = gen_text_vectors(nb)  # for Full Text Search
    elif vector_data_type == DataType.INT8_VECTOR:
        vectors = gen_int8_vectors(nb, dim)[1]
    elif vector_data_type == DataType.BINARY_VECTOR:
        vectors = gen_binary_vectors(nb, dim)[1]
    else:
        log.error(f"Invalid vector data type: {vector_data_type}")
        raise Exception(f"Invalid vector data type: {vector_data_type}")
    if dim > 1:
        if vector_data_type == DataType.FLOAT_VECTOR:
            vectors = preprocessing.normalize(vectors, axis=1, norm='l2')
            vectors = vectors.tolist()
    return vectors


def gen_string(nb):
    string_values = [str(random.random()) for _ in range(nb)]
    return string_values
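For context, the INT8_VECTOR branch above delegates to gen_int8_vectors (added later in this diff), which returns a pair of (raw Python lists, np.int8 arrays); gen_vectors forwards the packed arrays (index [1]) and only L2-normalizes FLOAT_VECTOR data. A hedged usage sketch, assuming it runs inside the test suite where common.common_func is importable:

    import numpy as np
    from pymilvus import DataType
    from common import common_func as cf  # test-suite helper module used throughout this diff

    vectors = cf.gen_vectors(nb=3, dim=8, vector_data_type=DataType.INT8_VECTOR)
    assert len(vectors) == 3
    # each entry is a packed np.int8 array from gen_int8_vectors(...)[1]; no normalization is applied
    assert vectors[0].dtype == np.int8 and vectors[0].shape == (8,)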
@@ -3141,7 +3169,8 @@ def extract_vector_field_name_list(collection_w):
        if field['type'] == DataType.FLOAT_VECTOR \
                or field['type'] == DataType.FLOAT16_VECTOR \
                or field['type'] == DataType.BFLOAT16_VECTOR \
                or field['type'] == DataType.SPARSE_FLOAT_VECTOR:
                or field['type'] == DataType.SPARSE_FLOAT_VECTOR \
                or field['type'] == DataType.INT8_VECTOR:
            if field['name'] != ct.default_float_vec_field_name:
                vector_name_list.append(field['name'])
@@ -3295,15 +3324,6 @@ def gen_sparse_vectors(nb, dim=1000, sparse_format="dok", empty_percentage=0):
    ]
    return vectors

def gen_int8_vectors(num, dim):
    raw_vectors = []
    int8_vectors = []
    for _ in range(num):
        raw_vector = [random.randint(-128, 127) for _ in range(dim)]
        raw_vectors.append(raw_vector)
        int8_vector = np.array(raw_vector, dtype=np.int8)
        int8_vectors.append(int8_vector)
    return raw_vectors, int8_vectors

def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
    vectors = []
@@ -3331,6 +3351,17 @@ def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
    return vectors


def gen_int8_vectors(num, dim):
    raw_vectors = []
    int8_vectors = []
    for _ in range(num):
        raw_vector = [random.randint(-128, 127) for _ in range(dim)]
        raw_vectors.append(raw_vector)
        int8_vector = np.array(raw_vector, dtype=np.int8)
        int8_vectors.append(int8_vector)
    return raw_vectors, int8_vectors


def gen_text_vectors(nb, language="en"):
    fake = Faker("en_US")
@@ -3339,6 +3370,7 @@ def gen_text_vectors(nb, language="en"):
    vectors = [" milvus " + fake.text() for _ in range(nb)]
    return vectors


def field_types() -> dict:
    return dict(sorted(dict(DataType.__members__).items(), key=lambda item: item[0], reverse=True))
@@ -68,10 +68,11 @@ default_metric_for_vector_type = {
    DataType.BINARY_VECTOR: "HAMMING",
}

all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]

append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]
append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.INT8_VECTOR]
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
all_vector_data_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
default_sparse_vec_field_name = "sparse_vector"
default_partition_name = "_default"
default_resource_group_name = '__default_resource_group'
@@ -254,6 +255,8 @@ all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ",

inverted_index_algo = ['TAAT_NAIVE', 'DAAT_WAND', 'DAAT_MAXSCORE']

int8_vector_index = ["HNSW"]

default_all_indexes_params = [{}, {"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
                              {"nlist": 128, "refine": 'true', "refine_type": "SQ8"},
                              {"M": 32, "efConstruction": 360}, {"nlist": 128}, {},
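Reviewer note: int8_vector_index above pins HNSW as the index type the suite exercises for INT8_VECTOR fields. A minimal hedged sketch of creating such a collection and index through MilvusClient, outside the test harness; the URI, collection name, and HNSW params are illustrative and assume a pymilvus/Milvus version with INT8_VECTOR support:

    from pymilvus import MilvusClient, DataType

    client = MilvusClient(uri="http://localhost:19530")  # assumed local deployment

    schema = MilvusClient.create_schema(enable_dynamic_field=False)
    schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
    schema.add_field("embeddings", DataType.INT8_VECTOR, dim=128)

    index_params = client.prepare_index_params()
    # HNSW is the only entry in ct.int8_vector_index; the params here are illustrative defaults
    index_params.add_index(field_name="embeddings", index_type="HNSW", metric_type="L2",
                           params={"M": 16, "efConstruction": 200})

    client.create_collection("int8_demo", schema=schema, index_params=index_params)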
@@ -278,9 +278,10 @@ class TestMilvusClientCollectionValid(TestMilvusClientV2Base):

    @pytest.mark.tags(CaseLabel.L0)
    @pytest.mark.parametrize("nullable", [True, False])
    def test_milvus_client_collection_self_creation_default(self, nullable):
    @pytest.mark.parametrize("vector_type", [DataType.FLOAT_VECTOR, DataType.INT8_VECTOR])
    def test_milvus_client_collection_self_creation_default(self, nullable, vector_type):
        """
        target: test fast create collection normal case
        target: test self create collection normal case
        method: create collection
        expected: create collection with default schema, index, and load successfully
        """
@@ -290,7 +291,7 @@ class TestMilvusClientCollectionValid(TestMilvusClientV2Base):
        # 1. create collection
        schema = self.create_schema(client, enable_dynamic_field=False)[0]
        schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
        schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
        schema.add_field("embeddings", vector_type, dim=dim)
        schema.add_field("title", DataType.VARCHAR, max_length=64, is_partition_key=True)
        schema.add_field("nullable_field", DataType.INT64, nullable=nullable, default_value=10)
        schema.add_field("array_field", DataType.ARRAY, element_type=DataType.INT64, max_capacity=12,
@@ -318,6 +319,46 @@ class TestMilvusClientCollectionValid(TestMilvusClientV2Base):
        if self.has_collection(client, collection_name)[0]:
            self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L2)
    def test_milvus_client_collection_self_creation_multiple_vectors(self):
        """
        target: test self create collection with multiple vectors
        method: create collection with multiple vectors
        expected: create collection with default schema, index, and load successfully
        """
        client = self._client()
        collection_name = cf.gen_unique_str(prefix)
        dim = 128
        # 1. create collection
        schema = self.create_schema(client, enable_dynamic_field=False)[0]
        schema.add_field("id_int64", DataType.INT64, is_primary=True, auto_id=False)
        schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
        schema.add_field("int8embeddings_1", DataType.INT8_VECTOR, dim=dim * 2)
        schema.add_field("int8embeddings_2", DataType.FLOAT16_VECTOR, dim=int(dim / 2))
        schema.add_field("int8embeddings_3", DataType.BFLOAT16_VECTOR, dim=int(dim / 2))
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index("embeddings", metric_type="COSINE")
        index_params.add_index("embeddings_1", metric_type="IP")
        index_params.add_index("embeddings_2", metric_type="L2")
        index_params.add_index("embeddings_3", metric_type="COSINE")
        # index_params.add_index("title")
        self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
        collections = self.list_collections(client)[0]
        assert collection_name in collections
        check_items = {"collection_name": collection_name,
                       "dim": [dim, dim * 2, dim / 2, dim / 2],
                       "consistency_level": 0,
                       "enable_dynamic_field": False,
                       "id_name": "id_int64",
                       "vector_name": ["embeddings", "embeddings_1", "embeddings_2", "embeddings_3"]}
        self.describe_collection(client, collection_name,
                                 check_task=CheckTasks.check_describe_collection_property,
                                 check_items=check_items)
        index = self.list_indexes(client, collection_name)[0]
        assert sorted(index) == sorted(['embeddings', 'embeddings_1', 'embeddings_2', 'embeddings_3'])
        if self.has_collection(client, collection_name)[0]:
            self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    def test_milvus_client_array_insert_search(self):
        """
@@ -207,6 +207,72 @@ class TestMilvusClientIndexInvalid(TestMilvusClientV2Base):
                          check_task=CheckTasks.err_res, check_items=error)
        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    @pytest.mark.parametrize("not_supported_index", ct.all_index_types[:-2])
    def test_milvus_client_int8_vector_create_not_supported_cpu_index(self, not_supported_index):
        """
        target: test create non-supported index on int8 vector
        method: create non-supported index on int8 vector
        expected: raise exception
        """
        if not_supported_index in ct.int8_vector_index:
            pytest.skip("This index is supported by int8 vector")
        client = self._client()
        collection_name = cf.gen_unique_str(prefix)
        dim = 128
        # 1. create collection
        schema = self.create_schema(client, enable_dynamic_field=False)[0]
        schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
        schema.add_field("embeddings", DataType.INT8_VECTOR, dim=dim)
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index("embeddings", metric_type="COSINE")
        # 2. index_params.add_index("title")
        self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
        self.release_collection(client, collection_name)
        self.drop_index(client, collection_name, "embeddings")
        # 3. prepare index params
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index(field_name="embeddings", index_type=not_supported_index, metric_type="L2")
        # 4. create another index
        error = {ct.err_code: 1100, ct.err_msg: f"data type Int8Vector can't build with this index {not_supported_index}: "
                                                f"invalid parameter[expected=valid index params][actual=invalid index params]"}
        self.create_index(client, collection_name, index_params,
                          check_task=CheckTasks.err_res, check_items=error)
        self.drop_collection(client, collection_name)

    @pytest.mark.tags(CaseLabel.L1)
    @pytest.mark.parametrize("not_supported_index", ct.all_index_types[-2:])
    def test_milvus_client_int8_vector_create_not_supported_GPU_index(self, not_supported_index):
        """
        target: test create non-supported index on int8 vector
        method: create non-supported index on int8 vector
        expected: raise exception
        """
        if not_supported_index in ct.int8_vector_index:
            pytest.skip("This index is supported by int8 vector")
        client = self._client()
        collection_name = cf.gen_unique_str(prefix)
        dim = 128
        # 1. create collection
        schema = self.create_schema(client, enable_dynamic_field=False)[0]
        schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
        schema.add_field("embeddings", DataType.INT8_VECTOR, dim=dim)
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index("embeddings", metric_type="COSINE")
        # 2. index_params.add_index("title")
        self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
        self.release_collection(client, collection_name)
        self.drop_index(client, collection_name, "embeddings")
        # 3. prepare index params
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index(field_name="embeddings", index_type=not_supported_index, metric_type="L2")
        # 4. create another index
        error = {ct.err_code: 1100, ct.err_msg: f"invalid parameter[expected=valid index][actual=invalid "
                                                f"index type: {not_supported_index}"}
        self.create_index(client, collection_name, index_params,
                          check_task=CheckTasks.err_res, check_items=error)
        self.drop_collection(client, collection_name)


class TestMilvusClientIndexValid(TestMilvusClientV2Base):
    """ Test case of index interface """
@@ -332,6 +332,15 @@ class TestMilvusClientInsertValid(TestMilvusClientV2Base):
    def metric_type(self, request):
        yield request.param

    @pytest.fixture(scope="function", params=[True, False])
    def nullable(self, request):
        yield request.param

    @pytest.fixture(scope="function", params=[DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR,
                                              DataType.BFLOAT16_VECTOR, DataType.INT8_VECTOR])
    def vector_type(self, request):
        yield request.param

    """
    ******************************************************************
    # The following are valid base cases
@@ -339,7 +348,7 @@ class TestMilvusClientInsertValid(TestMilvusClientV2Base):
    """

    @pytest.mark.tags(CaseLabel.L0)
    def test_milvus_client_insert_default(self):
    def test_milvus_client_insert_default(self, vector_type, nullable):
        """
        target: test search (high level api) normal case
        method: create connection, collection, insert and search
@@ -348,22 +357,25 @@ class TestMilvusClientInsertValid(TestMilvusClientV2Base):
        client = self._client()
        collection_name = cf.gen_unique_str(prefix)
        # 1. create collection
        self.create_collection(client, collection_name, default_dim, consistency_level="Strong")
        collections = self.list_collections(client)[0]
        assert collection_name in collections
        self.describe_collection(client, collection_name,
                                 check_task=CheckTasks.check_describe_collection_property,
                                 check_items={"collection_name": collection_name,
                                              "dim": default_dim,
                                              "consistency_level": 0})
        dim = 8
        # 1. create collection
        schema = self.create_schema(client, enable_dynamic_field=False)[0]
        schema.add_field(default_primary_key_field_name, DataType.INT64, max_length=64, is_primary=True, auto_id=False)
        schema.add_field(default_vector_field_name, vector_type, dim=dim)
        schema.add_field(default_string_field_name, DataType.VARCHAR, max_length=64, is_partition_key=True)
        schema.add_field(default_float_field_name, DataType.FLOAT, nullable=nullable)
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index(default_vector_field_name, metric_type="COSINE")
        self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
        # 2. insert
        rng = np.random.default_rng(seed=19530)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
        vectors = cf.gen_vectors(default_nb, dim, vector_data_type=vector_type)
        rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
                 default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
        results = self.insert(client, collection_name, rows)[0]
        assert results['insert_count'] == default_nb
        # 3. search
        vectors_to_search = rng.random((1, default_dim))
        vectors_to_search = [vectors[0]]
        insert_ids = [i for i in range(default_nb)]
        self.search(client, collection_name, vectors_to_search,
                    check_task=CheckTasks.check_search_results,
|
||||
check_task=CheckTasks.check_query_results,
|
||||
check_items={exp_res: rows,
|
||||
"with_vec": True,
|
||||
"pk_name": default_primary_key_field_name})
|
||||
"pk_name": default_primary_key_field_name,
|
||||
"vector_type": vector_type})
|
||||
self.release_collection(client, collection_name)
|
||||
self.drop_collection(client, collection_name)
|
||||
|
||||
|
||||
@@ -743,8 +743,11 @@ class TestCollectionSearchInvalid(TestcaseBase):
                            check_items={"err_code": 101,
                                         "err_msg": err_msg})
        # 3. search collection without data after load
        collection_w.create_index(
            ct.default_float_vec_field_name, index_params=ct.default_flat_index)
        if vector_data_type == DataType.INT8_VECTOR:
            collection_w.create_index(ct.default_float_vec_field_name,
                                      index_params={"index_type": "HNSW", "metric_type": "L2"})
        else:
            collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
        collection_w.load()
        collection_w.search(vectors[:default_nq], default_search_field, default_search_params,
                            default_limit, default_search_exp,