mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Add collection cases (#5239)
Add a part of the test case for the Collection interface See also: #5224 #5231 Signed-off-by: ThreadDao yufen.zong@zilliz.com
This commit is contained in:
parent
001795ee70
commit
1dabbfc835
@ -1,5 +1,6 @@
|
||||
from utils.util_log import test_log as log
|
||||
from common.common_type import *
|
||||
from pymilvus_orm import Collection
|
||||
|
||||
|
||||
class CheckFunc:
|
||||
@ -16,7 +17,7 @@ class CheckFunc:
|
||||
self.keys = self.params.keys()
|
||||
|
||||
def run(self):
|
||||
check_result = True
|
||||
check_result = None
|
||||
|
||||
if self.check_res is None:
|
||||
pass
|
||||
@ -28,6 +29,8 @@ class CheckFunc:
|
||||
elif self.check_res == CheckParams.list_count and self.check_params is not None:
|
||||
check_result = self.check_list_count(self.res, self.func_name, self.check_params)
|
||||
|
||||
elif self.check_res == CheckParams.collection_property_check:
|
||||
check_result = self.req_collection_property_check(self.res, self.func_name, self.params)
|
||||
return check_result
|
||||
|
||||
@staticmethod
|
||||
@ -156,3 +159,18 @@ class CheckFunc:
|
||||
assert message == "Invalid partition tag: %s. The length of a partition tag must be less than 255 characters." % str(params)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def req_collection_property_check(collection, func_name, params):
|
||||
'''
|
||||
:param collection
|
||||
:return:
|
||||
'''
|
||||
exp_func_name = "collection_init"
|
||||
if func_name != exp_func_name:
|
||||
log.warning("The function name is {} rather than {}".format(func_name, exp_func_name))
|
||||
if not isinstance(collection, Collection):
|
||||
raise Exception("The result to check isn't collection type object")
|
||||
assert collection.name == params["name"]
|
||||
assert collection.description == params["schema"].description
|
||||
assert collection.schema == params["schema"]
|
||||
@ -5,16 +5,61 @@ import numpy as np
|
||||
from sklearn import preprocessing
|
||||
|
||||
from pymilvus_orm.types import DataType
|
||||
from pymilvus_orm.schema import CollectionSchema, FieldSchema
|
||||
from utils.util_log import test_log as log
|
||||
from common.common_type import *
|
||||
|
||||
|
||||
"""" Methods of processing data """
|
||||
l2 = lambda x, y: np.linalg.norm(np.array(x) - np.array(y))
|
||||
|
||||
get_unique_str = "test_" + "".join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
|
||||
|
||||
|
||||
def gen_int64_field(is_primary=False):
|
||||
description = "int64 type field"
|
||||
int64_field = FieldSchema(name=default_int64_field, dtype=DataType.INT64, description=description,
|
||||
is_primary=is_primary)
|
||||
return int64_field
|
||||
|
||||
|
||||
def gen_float_field(is_primary=False):
|
||||
description = "float type field"
|
||||
float_field = FieldSchema(name=default_float_field, dtype=DataType.FLOAT, description=description,
|
||||
is_primary=is_primary)
|
||||
return float_field
|
||||
|
||||
|
||||
def gen_float_vec_field(is_primary=False):
|
||||
description = "float vector type field"
|
||||
float_vec_field = FieldSchema(name=default_float_vec_field_name, dtype=DataType.FLOAT_VECTOR,
|
||||
description=description, dim=default_dim, is_primary=is_primary)
|
||||
return float_vec_field
|
||||
|
||||
|
||||
def gen_binary_vec_field(is_primary=False):
|
||||
description = "binary vector type field"
|
||||
binary_vec_field = FieldSchema(name=default_binary_vec_field_name, dtype=DataType.BINARY_VECTOR,
|
||||
description=description, is_primary=is_primary)
|
||||
return binary_vec_field
|
||||
|
||||
|
||||
def gen_default_collection_schema():
|
||||
fields = [gen_int64_field(), gen_float_field(), gen_float_vec_field()]
|
||||
schema = CollectionSchema(fields=fields, description="default collection")
|
||||
return schema
|
||||
|
||||
|
||||
def gen_collection_schema(fields, description="collection", **kwargs):
|
||||
schema = CollectionSchema(fields=fields, description=description, **kwargs)
|
||||
return schema
|
||||
|
||||
|
||||
def gen_default_binary_collection_schema():
|
||||
fields = [gen_int64_field(), gen_float_field(), gen_binary_vec_field()]
|
||||
binary_schema = CollectionSchema(fields=fields, description="default binary collection")
|
||||
return binary_schema
|
||||
|
||||
|
||||
def get_binary_default_fields(auto_id=True):
|
||||
default_fields = {
|
||||
"fields": [
|
||||
|
||||
@ -13,6 +13,9 @@ max_top_k = 16384
|
||||
max_partition_num = 4096 # 256
|
||||
default_segment_row_limit = 1000
|
||||
default_server_segment_row_limit = 1024 * 512
|
||||
default_alias = "default"
|
||||
default_int64_field = "int64"
|
||||
default_float_field = "float"
|
||||
default_float_vec_field_name = "float_vector"
|
||||
default_binary_vec_field_name = "binary_vector"
|
||||
default_partition_name = "_default"
|
||||
@ -21,8 +24,19 @@ row_count = "row_count"
|
||||
|
||||
|
||||
"""" List of parameters used to pass """
|
||||
get_invalid_strs = [[], 1, [1, "2", 3], (1,), {1: 1}, None, "12-s", "12 s", "(mn)", "中文", "%$#",
|
||||
"a".join("a" for i in range(256))]
|
||||
get_invalid_strs = [
|
||||
[],
|
||||
1,
|
||||
[1, "2", 3],
|
||||
(1,),
|
||||
{1: 1},
|
||||
None,
|
||||
"12-s",
|
||||
"12 s",
|
||||
"(mn)",
|
||||
"中文",
|
||||
"%$#",
|
||||
"a".join("a" for i in range(256))]
|
||||
|
||||
|
||||
""" Specially defined list """
|
||||
@ -48,8 +62,8 @@ class CheckParams:
|
||||
""" The name of the method used to check the result """
|
||||
cname_param_check = "collection_name_param_check"
|
||||
pname_param_check = "partition_name_param_check"
|
||||
|
||||
list_count = "check_list_count"
|
||||
collection_property_check = "collection_property_check"
|
||||
|
||||
|
||||
class CaseLabel:
|
||||
|
||||
@ -2,11 +2,118 @@ import pytest
|
||||
from base.client_request import ApiReq
|
||||
from utils.util_log import test_log as log
|
||||
from common.common_type import *
|
||||
from common.common_func import *
|
||||
|
||||
default_schema = gen_default_collection_schema()
|
||||
|
||||
|
||||
class TestCollection(ApiReq):
|
||||
class TestCollectionParams(ApiReq):
|
||||
""" Test case of collection interface """
|
||||
|
||||
def teardown_method(self):
|
||||
if self.collection.collection is not None:
|
||||
self.collection.drop()
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=get_invalid_strs
|
||||
)
|
||||
def get_invalid_string(self, request):
|
||||
yield request.param
|
||||
|
||||
# #5224
|
||||
@pytest.mark.tags(CaseLabel.L3)
|
||||
def test_case(self):
|
||||
log.info("Test case of collection interface")
|
||||
def test_collection(self):
|
||||
"""
|
||||
target: test collection with default schema
|
||||
method: create collection with default schema
|
||||
expected: assert collection property
|
||||
"""
|
||||
self._connect()
|
||||
c_name = get_unique_str
|
||||
collection, _ = self.collection.collection_init(c_name, data=None, schema=default_schema)
|
||||
assert collection.name == c_name
|
||||
assert collection.description == default_schema.description
|
||||
assert collection.schema == default_schema
|
||||
assert collection.is_empty
|
||||
assert collection.num_entities == 0
|
||||
assert collection.primary_field is None
|
||||
assert c_name in self.utility.list_collections()
|
||||
|
||||
def test_collection_empty_name(self):
|
||||
"""
|
||||
target: test collection with empty name
|
||||
method: create collection with a empty name
|
||||
expected: raise exception
|
||||
"""
|
||||
self._connect()
|
||||
c_name = ""
|
||||
ex, check = self.collection.collection_init(c_name, schema=default_schema)
|
||||
assert "value is illegal" in str(ex)
|
||||
|
||||
def test_collection_invalid_name(self, get_invalid_string):
|
||||
"""
|
||||
target: test collection with invalid name
|
||||
method: create collection with invalid name
|
||||
expected: raise exception
|
||||
"""
|
||||
self._connect()
|
||||
c_name = get_invalid_string
|
||||
ex, check = self.collection.collection_init(c_name, schema=default_schema)
|
||||
assert "invalid" or "illegal" in str(ex)
|
||||
|
||||
# #5231 TODO
|
||||
def test_collection_dup_name(self):
|
||||
"""
|
||||
target: test collection with dup name
|
||||
method: create collection with dup name and none schema and data
|
||||
expected: collection properties consistent
|
||||
"""
|
||||
self._connect()
|
||||
c_name = get_unique_str
|
||||
collection, _ = self.collection.collection_init(c_name, data=None, schema=default_schema)
|
||||
assert collection.name == c_name
|
||||
dup_collection, _ = self.collection.collection_init(c_name)
|
||||
assert c_name, c_name in self.utility.list_collections()
|
||||
assert collection.name == dup_collection.name
|
||||
# log.debug(collection.schema)
|
||||
# log.debug(dup_collection.schema)
|
||||
# assert collection.schema == dup_collection.schema
|
||||
|
||||
def test_collection_dup_name_new_schema(self):
|
||||
"""
|
||||
target: test collection with dup name and new schema
|
||||
method: 1.create collection with default schema 2. collection with dup name and new schema
|
||||
expected: raise exception
|
||||
"""
|
||||
self._connect()
|
||||
c_name = get_unique_str
|
||||
collection, _ = self.collection.collection_init(c_name, data=None, schema=default_schema)
|
||||
assert collection.name == c_name
|
||||
fields = [gen_int64_field()]
|
||||
schema = gen_collection_schema(fields=fields)
|
||||
ex, _ = self.collection.collection_init(c_name, schema=schema)
|
||||
assert "The collection already exist, but the schema isnot the same as the passed in" in str(ex)
|
||||
|
||||
|
||||
class TestCollectionOperation(ApiReq):
|
||||
"""
|
||||
******************************************************************
|
||||
The following cases are used to test collection interface operations
|
||||
******************************************************************
|
||||
"""
|
||||
|
||||
def test_collection_without_connection(self):
|
||||
"""
|
||||
target: test collection without connection
|
||||
method: 1.create collection after connection removed
|
||||
expected: raise exception
|
||||
"""
|
||||
self._connect()
|
||||
self.connection.remove_connection(default_alias)
|
||||
res_list = self.connection.list_connections()
|
||||
assert len(res_list) == 0
|
||||
c_name = get_unique_str
|
||||
ex, check = self.collection.collection_init(c_name, schema=default_schema)
|
||||
assert "no connection" in str(ex)
|
||||
assert self.collection is None
|
||||
|
||||
@ -33,5 +33,3 @@ class TestConnection(ApiReq):
|
||||
self.connection.configure(check_res='', check_params=None, default={"host": "192.168.1.240", "port": "19530"})
|
||||
self.connection.get_connection(alias='default')
|
||||
self.connection.list_connections(check_res=CheckParams.list_count, check_params={"list_count": 1})
|
||||
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from config.test_info import test_info
|
||||
|
||||
@ -13,18 +14,24 @@ class TestLog:
|
||||
self.log.setLevel(logging.DEBUG)
|
||||
|
||||
try:
|
||||
formatter = logging.Formatter("[%(asctime)s - %(levelname)s - %(name)s]: %(message)s (%(filename)s:%(lineno)s)")
|
||||
fh = logging.FileHandler(log_file)
|
||||
fh.setLevel(logging.DEBUG)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
fh.setFormatter(formatter)
|
||||
self.log.addHandler(fh)
|
||||
|
||||
eh = logging.FileHandler(log_err)
|
||||
eh.setLevel(logging.ERROR)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
# formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
eh.setFormatter(formatter)
|
||||
self.log.addHandler(eh)
|
||||
|
||||
ch = logging.StreamHandler(sys.stdout)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
ch.setFormatter(formatter)
|
||||
self.log.addHandler(ch)
|
||||
|
||||
except Exception as e:
|
||||
print("Can not use %s or %s to log." % (log_file, log_err))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user