mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Add mix base
Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
This commit is contained in:
parent
aee7c74762
commit
ee9f67c216
@ -19,19 +19,18 @@ class TestConnect:
|
||||
else:
|
||||
return False
|
||||
|
||||
# TODO: remove
|
||||
def _test_disconnect(self, connect):
|
||||
@pytest.mark.tags("0331")
|
||||
def test_close(self, connect):
|
||||
'''
|
||||
target: test disconnect
|
||||
method: disconnect a connected client
|
||||
expected: connect failed after disconnected
|
||||
'''
|
||||
res = connect.close()
|
||||
connect.close()
|
||||
with pytest.raises(Exception) as e:
|
||||
res = connect.list_collections()
|
||||
connect.list_collections()
|
||||
|
||||
# TODO: remove
|
||||
def _test_disconnect_repeatedly(self, dis_connect, args):
|
||||
def test_close_repeatedly(self, dis_connect, args):
|
||||
'''
|
||||
target: test disconnect repeatedly
|
||||
method: disconnect a connected client, disconnect again
|
||||
@ -48,7 +47,6 @@ class TestConnect:
|
||||
expected: connected is True
|
||||
'''
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
# assert milvus.connected()
|
||||
|
||||
# TODO: Currently we test with remote IP, localhost testing need to add
|
||||
def _test_connect_ip_localhost(self, args):
|
||||
@ -62,6 +60,7 @@ class TestConnect:
|
||||
# assert milvus.connected()
|
||||
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_wrong_ip_null(self, args):
|
||||
'''
|
||||
target: test connect with wrong ip value
|
||||
@ -70,9 +69,9 @@ class TestConnect:
|
||||
'''
|
||||
ip = ""
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus = get_milvus(ip, args["port"], args["handler"])
|
||||
# assert not milvus.connected()
|
||||
get_milvus(ip, args["port"], args["handler"])
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_uri(self, args):
|
||||
'''
|
||||
target: test connect with correct uri
|
||||
@ -81,8 +80,8 @@ class TestConnect:
|
||||
'''
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
milvus = get_milvus(args["ip"], args["port"], uri=uri_value, handler=args["handler"])
|
||||
# assert milvus.connected()
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_uri_null(self, args):
|
||||
'''
|
||||
target: test connect with null uri
|
||||
@ -92,28 +91,28 @@ class TestConnect:
|
||||
uri_value = ""
|
||||
if self.local_ip(args):
|
||||
milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"])
|
||||
# assert milvus.connected()
|
||||
else:
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus = get_milvus(None, None, uri=uri_value, handler=args["handler"])
|
||||
# assert not milvus.connected()
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_with_multiprocess(self, args):
|
||||
'''
|
||||
target: test uri connect with multiprocess
|
||||
method: set correct uri, test with multiprocessing connecting
|
||||
expected: all connection is connected
|
||||
'''
|
||||
processes = []
|
||||
def connect():
|
||||
milvus = get_milvus(args["ip"], args["port"], handler=args["handler"])
|
||||
assert milvus
|
||||
assert milvus
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:
|
||||
future_results = {executor.submit(
|
||||
connect): i for i in range(100)}
|
||||
for future in concurrent.futures.as_completed(future_results):
|
||||
future.result()
|
||||
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_repeatedly(self, args):
|
||||
'''
|
||||
target: test connect repeatedly
|
||||
@ -122,10 +121,7 @@ class TestConnect:
|
||||
'''
|
||||
uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
milvus = Milvus(uri=uri_value, handler=args["handler"])
|
||||
# milvus.connect(uri=uri_value, timeout=5)
|
||||
# milvus.connect(uri=uri_value, timeout=5)
|
||||
milvus = Milvus(uri=uri_value, handler=args["handler"])
|
||||
# assert milvus.connected()
|
||||
|
||||
def _test_add_vector_and_disconnect_concurrently(self):
|
||||
'''
|
||||
@ -153,19 +149,20 @@ class TestConnect:
|
||||
pass
|
||||
|
||||
def _test_thread_safe_with_one_connection_shared_in_multi_threads(self):
|
||||
'''
|
||||
'''
|
||||
Target: test 1 connection thread safe
|
||||
Method: 1 connection shared in multi-threads, all adding vectors, or other things
|
||||
Expected: Functional as one thread
|
||||
|
||||
'''
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class TestConnectIPInvalid(object):
|
||||
"""
|
||||
Test connect server with invalid ip
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_ips()
|
||||
@ -175,11 +172,11 @@ class TestConnectIPInvalid(object):
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_with_invalid_ip(self, args, get_invalid_ip):
|
||||
ip = get_invalid_ip
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus = get_milvus(ip, args["port"], args["handler"])
|
||||
# assert not milvus.connected()
|
||||
|
||||
|
||||
class TestConnectPortInvalid(object):
|
||||
@ -196,6 +193,7 @@ class TestConnectPortInvalid(object):
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_with_invalid_port(self, args, get_invalid_port):
|
||||
'''
|
||||
target: test ip:port connect with invalid port value
|
||||
@ -205,13 +203,13 @@ class TestConnectPortInvalid(object):
|
||||
port = get_invalid_port
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus = get_milvus(args["ip"], port, args["handler"])
|
||||
# assert not milvus.connected()
|
||||
|
||||
|
||||
class TestConnectURIInvalid(object):
|
||||
"""
|
||||
Test connect server with invalid uri
|
||||
"""
|
||||
|
||||
@pytest.fixture(
|
||||
scope="function",
|
||||
params=gen_invalid_uris()
|
||||
@ -221,6 +219,7 @@ class TestConnectURIInvalid(object):
|
||||
|
||||
@pytest.mark.level(2)
|
||||
@pytest.mark.timeout(CONNECT_TIMEOUT)
|
||||
@pytest.mark.tags("0331")
|
||||
def test_connect_with_invalid_uri(self, get_invalid_uri, args):
|
||||
'''
|
||||
target: test uri connect with invalid uri value
|
||||
@ -230,4 +229,3 @@ class TestConnectURIInvalid(object):
|
||||
uri_value = get_invalid_uri
|
||||
with pytest.raises(Exception) as e:
|
||||
milvus = get_milvus(uri=uri_value, handler=args["handler"])
|
||||
# assert not milvus.connected()
|
||||
|
||||
@ -18,27 +18,50 @@ nprobe = 1
|
||||
epsilon = 0.001
|
||||
nlist = 128
|
||||
# index_params = {'index_type': IndexType.IVFLAT, 'nlist': 16384}
|
||||
default_index = {"index_type": "IVF_FLAT", "params": {"nlist": 16384}, "metric_type": "L2"}
|
||||
|
||||
|
||||
class TestMixBase:
|
||||
# TODO
|
||||
def _test_mix_base(self, connect, collection):
|
||||
nb = 200000
|
||||
nq = 5
|
||||
entities = gen_entities(nb=nb)
|
||||
ids = connect.insert(collection, entities)
|
||||
assert len(ids) == nb
|
||||
connect.flush([collection])
|
||||
connect.create_index(collection, default_float_vec_field_name, default_index)
|
||||
index = connect.describe_index(collection, default_float_vec_field_name)
|
||||
assert index == default_index
|
||||
query, vecs = gen_query_vectors(default_float_vec_field_name, entities, default_top_k, nq)
|
||||
connect.load_collection(collection)
|
||||
res = connect.search(collection, query)
|
||||
assert len(res) == nq
|
||||
assert len(res[0]) == default_top_k
|
||||
assert res[0]._distances[0] <= epsilon
|
||||
assert check_id_result(res[0], ids[0])
|
||||
|
||||
# disable
|
||||
def _test_search_during_createIndex(self, args):
|
||||
loops = 10000
|
||||
collection = gen_unique_str()
|
||||
query_vecs = [vectors[0], vectors[1]]
|
||||
uri = "tcp://%s:%s" % (args["ip"], args["port"])
|
||||
id_0 = 0; id_1 = 0
|
||||
id_0 = 0;
|
||||
id_1 = 0
|
||||
milvus_instance = get_milvus(args["handler"])
|
||||
# milvus_instance.connect(uri=uri)
|
||||
milvus_instance.create_collection({'collection_name': collection,
|
||||
'dimension': default_dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': "L2"})
|
||||
'dimension': default_dim,
|
||||
'index_file_size': index_file_size,
|
||||
'metric_type': "L2"})
|
||||
for i in range(10):
|
||||
status, ids = milvus_instance.bulk_insert(collection, vectors)
|
||||
# logging.getLogger().info(ids)
|
||||
if i == 0:
|
||||
id_0 = ids[0]; id_1 = ids[1]
|
||||
id_0 = ids[0];
|
||||
id_1 = ids[1]
|
||||
|
||||
# def create_index(milvus_instance):
|
||||
# logging.getLogger().info("In create index")
|
||||
# status = milvus_instance.create_index(collection, index_params)
|
||||
@ -49,6 +72,7 @@ class TestMixBase:
|
||||
logging.getLogger().info("In add vectors")
|
||||
status, ids = milvus_instance.bulk_insert(collection, vectors)
|
||||
logging.getLogger().info(status)
|
||||
|
||||
def search(milvus_instance):
|
||||
logging.getLogger().info("In search vectors")
|
||||
for i in range(loops):
|
||||
@ -56,13 +80,14 @@ class TestMixBase:
|
||||
logging.getLogger().info(status)
|
||||
assert result[0][0].id == id_0
|
||||
assert result[1][0].id == id_1
|
||||
|
||||
milvus_instance = get_milvus(args["handler"])
|
||||
# milvus_instance.connect(uri=uri)
|
||||
p_search = Process(target=search, args=(milvus_instance, ))
|
||||
p_search = Process(target=search, args=(milvus_instance,))
|
||||
p_search.start()
|
||||
milvus_instance = get_milvus(args["handler"])
|
||||
# milvus_instance.connect(uri=uri)
|
||||
p_create = Process(target=insert, args=(milvus_instance, ))
|
||||
p_create = Process(target=insert, args=(milvus_instance,))
|
||||
p_create.start()
|
||||
p_create.join()
|
||||
|
||||
@ -79,7 +104,7 @@ class TestMixBase:
|
||||
idx = []
|
||||
index_param = {'nlist': nlist}
|
||||
|
||||
#create collection and add vectors
|
||||
# create collection and add vectors
|
||||
for i in range(30):
|
||||
collection_name = gen_unique_str('test_mix_multi_collections')
|
||||
collection_list.append(collection_name)
|
||||
@ -123,7 +148,7 @@ class TestMixBase:
|
||||
status = connect.create_index(collection_list[50 + i], IndexType.IVF_SQ8, index_param)
|
||||
assert status.OK()
|
||||
|
||||
#describe index
|
||||
# describe index
|
||||
for i in range(10):
|
||||
status, result = connect.get_index_info(collection_list[i])
|
||||
assert result._index_type == IndexType.FLAT
|
||||
@ -138,7 +163,7 @@ class TestMixBase:
|
||||
status, result = connect.get_index_info(collection_list[50 + i])
|
||||
assert result._index_type == IndexType.IVF_SQ8
|
||||
|
||||
#search
|
||||
# search
|
||||
query_vecs = [vectors[0], vectors[10], vectors[20]]
|
||||
for i in range(60):
|
||||
collection = collection_list[i]
|
||||
@ -154,8 +179,18 @@ class TestMixBase:
|
||||
logging.getLogger().info(idx[3 * i + j])
|
||||
assert check_result(result[j], idx[3 * i + j])
|
||||
|
||||
|
||||
def check_result(result, id):
|
||||
if len(result) >= 5:
|
||||
return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
|
||||
else:
|
||||
return id in (i.id for i in result)
|
||||
|
||||
|
||||
def check_id_result(result, id):
|
||||
limit_in = 5
|
||||
ids = [entity.id for entity in result]
|
||||
if len(result) >= limit_in:
|
||||
return id in ids[:limit_in]
|
||||
else:
|
||||
return id in ids
|
||||
Loading…
x
Reference in New Issue
Block a user