test: add test cases for int8 vector (#41957)

Signed-off-by: binbin lv <binbin.lv@zilliz.com>
This commit is contained in:
binbin 2025-05-23 09:24:28 +08:00 committed by GitHub
parent 1735f557ca
commit eea6b50fbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 257 additions and 39 deletions

View File

@ -368,6 +368,8 @@ class TestcaseBase(Base):
# Unlike dense vectors, sparse vectors cannot create flat index.
if DataType.SPARSE_FLOAT_VECTOR.name in vector_name:
collection_w.create_index(vector_name, ct.default_sparse_inverted_index)
elif vector_data_type == DataType.INT8_VECTOR:
collection_w.create_index(vector_name, ct.int8_vector_index)
else:
collection_w.create_index(vector_name, ct.default_flat_index)

View File

@ -1100,4 +1100,26 @@ class TestMilvusClientV2Base(Base):
source_group=source_group, target_group=target_group,
collection_name=collection_name, num_replicas=num_replicas,
**kwargs).run()
return res, check_result
@trace()
def create_field_schema(self, client, name, data_type, desc='', timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.create_field_schema, name, data_type, desc], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
**kwargs).run()
return res, check_result
@trace()
def add_collection_field(self, client, collection_name, field_schema, timeout=None, check_task=None, check_items=None, **kwargs):
timeout = TIMEOUT if timeout is None else timeout
kwargs.update({"timeout": timeout})
func_name = sys._getframe().f_code.co_name
res, check = api_request([client.add_collection_field, collection_name, field_schema], **kwargs)
check_result = ResponseChecker(res, func_name, check_task, check_items, check,
**kwargs).run()
return res, check_result

View File

@ -7,9 +7,10 @@ from common import common_type as ct
from common import common_func as cf
from common.common_type import CheckTasks, Connect_Object_Name
# from common.code_mapping import ErrorCode, ErrorMessage
from pymilvus import Collection, Partition, ResourceGroupInfo
from pymilvus import Collection, Partition, ResourceGroupInfo, DataType
import check.param_check as pc
import numpy as np
from ml_dtypes import bfloat16
class Error:
def __init__(self, error):
@ -259,8 +260,27 @@ class ResponseChecker:
if check_items.get("id_name", "id"):
assert res["fields"][0]["name"] == check_items.get("id_name", "id")
if check_items.get("vector_name", "vector"):
assert res["fields"][1]["name"] == check_items.get("vector_name", "vector")
vector_name_list = []
vector_name_list_expected = check_items.get("vector_name", "vector")
for field in res["fields"]:
if field["type"] in [101, 102, 103, 105]:
vector_name_list.append(field["name"])
if isinstance(vector_name_list_expected, str):
assert vector_name_list[0] == check_items.get("vector_name", "vector")
else:
assert vector_name_list == vector_name_list_expected
if check_items.get("dim", None) is not None:
dim_list = []
# here dim support int for only one vector field and list for multiple vector fields, and the order
# should be the same of the order adding schema
dim_list_expected = check_items.get("dim")
for field in res["fields"]:
if field["type"] in [101, 102, 103, 105]:
dim_list.append(field["params"]["dim"])
if isinstance(dim_list_expected, int):
assert dim_list[0] == dim_list_expected
else:
assert dim_list == dim_list_expected
assert res["fields"][1]["params"]["dim"] == check_items.get("dim")
if check_items.get("nullable_fields", None) is not None:
nullable_fields = check_items.get("nullable_fields")
@ -272,7 +292,7 @@ class ResponseChecker:
assert field["nullable"] is True
assert res["fields"][0]["is_primary"] is True
assert res["fields"][0]["field_id"] == 100 and (res["fields"][0]["type"] == 5 or 21)
assert res["fields"][1]["field_id"] == 101 and res["fields"][1]["type"] == 101
assert res["fields"][1]["field_id"] == 101 and (res["fields"][1]["type"] == 101 or 105)
return True
@ -540,6 +560,22 @@ class ResponseChecker:
exp_res = check_items.get("exp_res", None)
with_vec = check_items.get("with_vec", False)
pk_name = check_items.get("pk_name", ct.default_primary_field_name)
vector_type = check_items.get("vector_type", "FLOAT_VECTOR")
if vector_type == DataType.FLOAT16_VECTOR:
for single_exp_res in exp_res:
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
for single_query_result in query_res:
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.float16).tolist()
if vector_type == DataType.BFLOAT16_VECTOR:
for single_exp_res in exp_res:
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
for single_query_result in query_res:
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=bfloat16).tolist()
if vector_type == DataType.INT8_VECTOR:
for single_exp_res in exp_res:
single_exp_res['vector'] = single_exp_res['vector'] .tolist()
for single_query_result in query_res:
single_query_result['vector'] = np.frombuffer(single_query_result['vector'][0], dtype=np.int8).tolist()
if exp_res is not None:
if isinstance(query_res, list):
assert pc.equal_entities_list(exp=exp_res, actual=query_res, primary_field=pk_name,

View File

@ -696,17 +696,18 @@ def gen_float_vec_field(name=ct.default_float_vec_field_name, is_primary=False,
if vector_data_type != DataType.SPARSE_FLOAT_VECTOR:
float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=vector_data_type,
description=description, dim=dim,
is_primary=is_primary, **kwargs)
description=description, dim=dim,
is_primary=is_primary, **kwargs)
else:
# no dim for sparse vector
float_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.SPARSE_FLOAT_VECTOR,
description=description,
is_primary=is_primary, **kwargs)
description=description,
is_primary=is_primary, **kwargs)
return float_vec_field
def gen_binary_vec_field(name=ct.default_binary_vec_field_name, is_primary=False, dim=ct.default_dim,
description=ct.default_desc, **kwargs):
binary_vec_field, _ = ApiFieldSchemaWrapper().init_field_schema(name=name, dtype=DataType.BINARY_VECTOR,
@ -792,7 +793,8 @@ def gen_default_collection_schema(description=ct.default_desc, primary_field=ct.
if len(multiple_dim_array) != 0:
for other_dim in multiple_dim_array:
fields.append(gen_float_vec_field(gen_unique_str("multiple_vector"), dim=other_dim,
name_prefix = "multiple_vector"
fields.append(gen_float_vec_field(gen_unique_str(name_prefix), dim=other_dim,
vector_data_type=vector_data_type))
schema, _ = ApiCollectionSchemaWrapper().init_collection_schema(fields=fields, description=description,
@ -1120,6 +1122,32 @@ def gen_schema_multi_string_fields(string_fields):
return schema
def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
vectors = []
if vector_data_type == DataType.FLOAT_VECTOR:
vectors = [[random.random() for _ in range(dim)] for _ in range(nb)]
elif vector_data_type == DataType.FLOAT16_VECTOR:
vectors = gen_fp16_vectors(nb, dim)[1]
elif vector_data_type == DataType.BFLOAT16_VECTOR:
vectors = gen_bf16_vectors(nb, dim)[1]
elif vector_data_type == DataType.SPARSE_FLOAT_VECTOR:
vectors = gen_sparse_vectors(nb, dim)
elif vector_data_type == ct.text_sparse_vector:
vectors = gen_text_vectors(nb) # for Full Text Search
elif vector_data_type == DataType.INT8_VECTOR:
vectors = gen_int8_vectors(nb, dim)[1]
elif vector_data_type == DataType.BINARY_VECTOR:
vectors = gen_binary_vectors(nb, dim)[1]
else:
log.error(f"Invalid vector data type: {vector_data_type}")
raise Exception(f"Invalid vector data type: {vector_data_type}")
if dim > 1:
if vector_data_type == DataType.FLOAT_VECTOR:
vectors = preprocessing.normalize(vectors, axis=1, norm='l2')
vectors = vectors.tolist()
return vectors
def gen_string(nb):
string_values = [str(random.random()) for _ in range(nb)]
return string_values
@ -3141,7 +3169,8 @@ def extract_vector_field_name_list(collection_w):
if field['type'] == DataType.FLOAT_VECTOR \
or field['type'] == DataType.FLOAT16_VECTOR \
or field['type'] == DataType.BFLOAT16_VECTOR \
or field['type'] == DataType.SPARSE_FLOAT_VECTOR:
or field['type'] == DataType.SPARSE_FLOAT_VECTOR\
or field['type'] == DataType.INT8_VECTOR:
if field['name'] != ct.default_float_vec_field_name:
vector_name_list.append(field['name'])
@ -3295,15 +3324,6 @@ def gen_sparse_vectors(nb, dim=1000, sparse_format="dok", empty_percentage=0):
]
return vectors
def gen_int8_vectors(num, dim):
raw_vectors = []
int8_vectors = []
for _ in range(num):
raw_vector = [random.randint(-128, 127) for _ in range(dim)]
raw_vectors.append(raw_vector)
int8_vector = np.array(raw_vector, dtype=np.int8)
int8_vectors.append(int8_vector)
return raw_vectors, int8_vectors
def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
vectors = []
@ -3331,6 +3351,17 @@ def gen_vectors(nb, dim, vector_data_type=DataType.FLOAT_VECTOR):
return vectors
def gen_int8_vectors(num, dim):
raw_vectors = []
int8_vectors = []
for _ in range(num):
raw_vector = [random.randint(-128, 127) for _ in range(dim)]
raw_vectors.append(raw_vector)
int8_vector = np.array(raw_vector, dtype=np.int8)
int8_vectors.append(int8_vector)
return raw_vectors, int8_vectors
def gen_text_vectors(nb, language="en"):
fake = Faker("en_US")
@ -3339,6 +3370,7 @@ def gen_text_vectors(nb, language="en"):
vectors = [" milvus " + fake.text() for _ in range(nb)]
return vectors
def field_types() -> dict:
return dict(sorted(dict(DataType.__members__).items(), key=lambda item: item[0], reverse=True))

View File

@ -68,10 +68,11 @@ default_metric_for_vector_type = {
DataType.BINARY_VECTOR: "HAMMING",
}
all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]
append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR]
append_vector_type = [DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
all_dense_vector_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.INT8_VECTOR]
all_float_vector_dtypes = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
all_vector_data_types = [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR, DataType.SPARSE_FLOAT_VECTOR, DataType.INT8_VECTOR]
default_sparse_vec_field_name = "sparse_vector"
default_partition_name = "_default"
default_resource_group_name = '__default_resource_group'
@ -254,6 +255,8 @@ all_index_types = ["FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ",
inverted_index_algo = ['TAAT_NAIVE', 'DAAT_WAND', 'DAAT_MAXSCORE']
int8_vector_index = ["HNSW"]
default_all_indexes_params = [{}, {"nlist": 128}, {"nlist": 128}, {"nlist": 128, "m": 16, "nbits": 8},
{"nlist": 128, "refine": 'true', "refine_type": "SQ8"},
{"M": 32, "efConstruction": 360}, {"nlist": 128}, {},

View File

@ -278,9 +278,10 @@ class TestMilvusClientCollectionValid(TestMilvusClientV2Base):
@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.parametrize("nullable", [True, False])
def test_milvus_client_collection_self_creation_default(self, nullable):
@pytest.mark.parametrize("vector_type", [DataType.FLOAT_VECTOR, DataType.INT8_VECTOR])
def test_milvus_client_collection_self_creation_default(self, nullable, vector_type):
"""
target: test fast create collection normal case
target: test self create collection normal case
method: create collection
expected: create collection with default schema, index, and load successfully
"""
@ -290,7 +291,7 @@ class TestMilvusClientCollectionValid(TestMilvusClientV2Base):
# 1. create collection
schema = self.create_schema(client, enable_dynamic_field=False)[0]
schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
schema.add_field("embeddings", vector_type, dim=dim)
schema.add_field("title", DataType.VARCHAR, max_length=64, is_partition_key=True)
schema.add_field("nullable_field", DataType.INT64, nullable=nullable, default_value=10)
schema.add_field("array_field", DataType.ARRAY, element_type=DataType.INT64, max_capacity=12,
@ -318,6 +319,46 @@ class TestMilvusClientCollectionValid(TestMilvusClientV2Base):
if self.has_collection(client, collection_name)[0]:
self.drop_collection(client, collection_name)
@pytest.mark.tags(CaseLabel.L2)
def test_milvus_client_collection_self_creation_multiple_vectors(self):
"""
target: test self create collection with multiple vectors
method: create collection with multiple vectors
expected: create collection with default schema, index, and load successfully
"""
client = self._client()
collection_name = cf.gen_unique_str(prefix)
dim = 128
# 1. create collection
schema = self.create_schema(client, enable_dynamic_field=False)[0]
schema.add_field("id_int64", DataType.INT64, is_primary=True, auto_id=False)
schema.add_field("embeddings", DataType.FLOAT_VECTOR, dim=dim)
schema.add_field("int8embeddings_1", DataType.INT8_VECTOR, dim=dim * 2)
schema.add_field("int8embeddings_2", DataType.FLOAT16_VECTOR, dim=int(dim / 2))
schema.add_field("int8embeddings_3", DataType.BFLOAT16_VECTOR, dim=int(dim / 2))
index_params = self.prepare_index_params(client)[0]
index_params.add_index("embeddings", metric_type="COSINE")
index_params.add_index("embeddings_1", metric_type="IP")
index_params.add_index("embeddings_2", metric_type="L2")
index_params.add_index("embeddings_3", metric_type="COSINE")
# index_params.add_index("title")
self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
collections = self.list_collections(client)[0]
assert collection_name in collections
check_items = {"collection_name": collection_name,
"dim": [dim, dim * 2, dim / 2, dim / 2],
"consistency_level": 0,
"enable_dynamic_field": False,
"id_name": "id_int64",
"vector_name": ["embeddings", "embeddings_1", "embeddings_2", "embeddings_3"]}
self.describe_collection(client, collection_name,
check_task=CheckTasks.check_describe_collection_property,
check_items=check_items)
index = self.list_indexes(client, collection_name)[0]
assert sorted(index) == sorted(['embeddings', 'embeddings_1', 'embeddings_2', 'embeddings_3'])
if self.has_collection(client, collection_name)[0]:
self.drop_collection(client, collection_name)
@pytest.mark.tags(CaseLabel.L1)
def test_milvus_client_array_insert_search(self):
"""

View File

@ -207,6 +207,72 @@ class TestMilvusClientIndexInvalid(TestMilvusClientV2Base):
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("not_supported_index", ct.all_index_types[:-2])
def test_milvus_client_int8_vector_create_not_supported_cpu_index(self, not_supported_index):
"""
target: test create non-supported index on int8 vector
method: create non-supported index on int8 vector
expected: raise exception
"""
if not_supported_index in ct.int8_vector_index:
pytest.skip("This index is supported by int8 vector")
client = self._client()
collection_name = cf.gen_unique_str(prefix)
dim = 128
# 1. create collection
schema = self.create_schema(client, enable_dynamic_field=False)[0]
schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
schema.add_field("embeddings", DataType.INT8_VECTOR, dim=dim)
index_params = self.prepare_index_params(client)[0]
index_params.add_index("embeddings", metric_type="COSINE")
# 2. index_params.add_index("title")
self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
self.release_collection(client, collection_name)
self.drop_index(client, collection_name, "embeddings")
# 3. prepare index params
index_params = self.prepare_index_params(client)[0]
index_params.add_index(field_name="embeddings", index_type=not_supported_index, metric_type="L2")
# 4. create another index
error = {ct.err_code: 1100, ct.err_msg: f"data type Int8Vector can't build with this index {not_supported_index}: "
f"invalid parameter[expected=valid index params][actual=invalid index params]"}
self.create_index(client, collection_name, index_params,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("not_supported_index", ct.all_index_types[-2:])
def test_milvus_client_int8_vector_create_not_supported_GPU_index(self, not_supported_index):
"""
target: test create non-supported index on int8 vector
method: create non-supported index on int8 vector
expected: raise exception
"""
if not_supported_index in ct.int8_vector_index:
pytest.skip("This index is supported by int8 vector")
client = self._client()
collection_name = cf.gen_unique_str(prefix)
dim = 128
# 1. create collection
schema = self.create_schema(client, enable_dynamic_field=False)[0]
schema.add_field("id_string", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False)
schema.add_field("embeddings", DataType.INT8_VECTOR, dim=dim)
index_params = self.prepare_index_params(client)[0]
index_params.add_index("embeddings", metric_type="COSINE")
# 2. index_params.add_index("title")
self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
self.release_collection(client, collection_name)
self.drop_index(client, collection_name, "embeddings")
# 3. prepare index params
index_params = self.prepare_index_params(client)[0]
index_params.add_index(field_name="embeddings", index_type=not_supported_index, metric_type="L2")
# 4. create another index
error = {ct.err_code: 1100, ct.err_msg: f"invalid parameter[expected=valid index][actual=invalid "
f"index type: {not_supported_index}"}
self.create_index(client, collection_name, index_params,
check_task=CheckTasks.err_res, check_items=error)
self.drop_collection(client, collection_name)
class TestMilvusClientIndexValid(TestMilvusClientV2Base):
""" Test case of index interface """

View File

@ -332,6 +332,15 @@ class TestMilvusClientInsertValid(TestMilvusClientV2Base):
def metric_type(self, request):
yield request.param
@pytest.fixture(scope="function", params=[True, False])
def nullable(self, request):
yield request.param
@pytest.fixture(scope="function", params=[DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR,
DataType.BFLOAT16_VECTOR, DataType.INT8_VECTOR])
def vector_type(self, request):
yield request.param
"""
******************************************************************
# The following are valid base cases
@ -339,7 +348,7 @@ class TestMilvusClientInsertValid(TestMilvusClientV2Base):
"""
@pytest.mark.tags(CaseLabel.L0)
def test_milvus_client_insert_default(self):
def test_milvus_client_insert_default(self, vector_type, nullable):
"""
target: test search (high level api) normal case
method: create connection, collection, insert and search
@ -348,22 +357,25 @@ class TestMilvusClientInsertValid(TestMilvusClientV2Base):
client = self._client()
collection_name = cf.gen_unique_str(prefix)
# 1. create collection
self.create_collection(client, collection_name, default_dim, consistency_level="Strong")
collections = self.list_collections(client)[0]
assert collection_name in collections
self.describe_collection(client, collection_name,
check_task=CheckTasks.check_describe_collection_property,
check_items={"collection_name": collection_name,
"dim": default_dim,
"consistency_level": 0})
dim = 8
# 1. create collection
schema = self.create_schema(client, enable_dynamic_field=False)[0]
schema.add_field(default_primary_key_field_name, DataType.INT64, max_length=64, is_primary=True, auto_id=False)
schema.add_field(default_vector_field_name, vector_type, dim=dim)
schema.add_field(default_string_field_name, DataType.VARCHAR, max_length=64, is_partition_key=True)
schema.add_field(default_float_field_name, DataType.FLOAT, nullable=nullable)
index_params = self.prepare_index_params(client)[0]
index_params.add_index(default_vector_field_name, metric_type="COSINE")
self.create_collection(client, collection_name, dimension=dim, schema=schema, index_params=index_params)
# 2. insert
rng = np.random.default_rng(seed=19530)
rows = [{default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]),
vectors = cf.gen_vectors(default_nb, dim, vector_data_type=vector_type)
rows = [{default_primary_key_field_name: i, default_vector_field_name: vectors[i],
default_float_field_name: i * 1.0, default_string_field_name: str(i)} for i in range(default_nb)]
results = self.insert(client, collection_name, rows)[0]
assert results['insert_count'] == default_nb
# 3. search
vectors_to_search = rng.random((1, default_dim))
vectors_to_search = [vectors[0]]
insert_ids = [i for i in range(default_nb)]
self.search(client, collection_name, vectors_to_search,
check_task=CheckTasks.check_search_results,
@ -377,7 +389,8 @@ class TestMilvusClientInsertValid(TestMilvusClientV2Base):
check_task=CheckTasks.check_query_results,
check_items={exp_res: rows,
"with_vec": True,
"pk_name": default_primary_key_field_name})
"pk_name": default_primary_key_field_name,
"vector_type": vector_type})
self.release_collection(client, collection_name)
self.drop_collection(client, collection_name)

View File

@ -743,8 +743,11 @@ class TestCollectionSearchInvalid(TestcaseBase):
check_items={"err_code": 101,
"err_msg": err_msg})
# 3. search collection without data after load
collection_w.create_index(
ct.default_float_vec_field_name, index_params=ct.default_flat_index)
if vector_data_type == DataType.INT8_VECTOR:
collection_w.create_index(ct.default_float_vec_field_name,
index_params={"index_type": "HNSW", "metric_type": "L2"})
else:
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load()
collection_w.search(vectors[:default_nq], default_search_field, default_search_params,
default_limit, default_search_exp,