add test insert binary (#3853)

Signed-off-by: ThreadDao <zongyufen@foxmail.com>
This commit is contained in:
ThreadDao 2020-09-24 14:11:13 +08:00 committed by GitHub
parent bbeb56f41e
commit a04a8ad1ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,3 +1,4 @@
import logging
import time
import pdb
import copy
@ -11,6 +12,7 @@ from constants import *
ADD_TIMEOUT = 60
uid = "test_insert"
field_name = default_float_vec_field_name
binary_field_name = default_binary_vec_field_name
default_single_query = {
"bool": {
"must": [
@ -20,6 +22,7 @@ default_single_query = {
}
}
class TestInsertBase:
"""
******************************************************************
@ -116,6 +119,11 @@ class TestInsertBase:
assert len(ids) == default_nb
connect.flush([collection])
connect.create_index(collection, field_name, get_simple_index)
info = connect.get_collection_info(collection)
fields = info["fields"]
for field in fields:
if field["field"] == field_name:
assert field["indexes"][0] == get_simple_index
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_after_create_index(self, connect, collection, get_simple_index):
@ -127,6 +135,11 @@ class TestInsertBase:
connect.create_index(collection, field_name, get_simple_index)
ids = connect.insert(collection, default_entities)
assert len(ids) == default_nb
info = connect.get_collection_info(collection)
fields = info["fields"]
for field in fields:
if field["field"] == field_name:
assert field["indexes"][0] == get_simple_index
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_search(self, connect, collection):
@ -300,6 +313,7 @@ class TestInsertBase:
connect.create_partition(collection, default_tag)
ids = connect.insert(collection, default_entities, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(collection, default_tag)
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_tag_with_ids(self, connect, id_collection):
@ -533,6 +547,96 @@ class TestInsertBase:
assert res[0] is None
class TestInsertBinary:
@pytest.fixture(
scope="function",
params=gen_binary_index()
)
def get_binary_index(self, request):
request.param["metric_type"] = "JACCARD"
return request.param
def test_insert_binary_entities(self, connect, binary_collection):
'''
target: test insert entities in binary collection
method: create collection and insert binary entities in it
expected: the collection row count equals to nb
'''
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.flush()
assert connect.count_entities(binary_collection) == default_nb
def test_insert_binary_tag(self, connect, binary_collection):
'''
target: test insert entities and create partition tag
method: create collection and insert binary entities in it, with the partition_tag param
expected: the collection row count equals to nb
'''
connect.create_partition(binary_collection, default_tag)
ids = connect.insert(binary_collection, default_binary_entities, partition_tag=default_tag)
assert len(ids) == default_nb
assert connect.has_partition(binary_collection, default_tag)
def test_insert_binary_multi_times(self, connect, binary_collection):
'''
target: test insert entities multi times and final flush
method: create collection and insert binary entity multi and final flush
expected: the collection row count equals to nb
'''
for i in range(default_nb):
ids = connect.insert(binary_collection, default_binary_entity)
assert len(ids) == 1
connect.flush([binary_collection])
assert connect.count_entities(binary_collection) == default_nb
def test_insert_binary_after_create_index(self, connect, binary_collection, get_binary_index):
'''
target: test insert binary entities after build index
method: build index and insert entities
expected: no error raised
'''
connect.create_index(binary_collection, binary_field_name, get_binary_index)
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.flush([binary_collection])
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["field"] == binary_field_name:
assert field["indexes"][0] == get_binary_index
@pytest.mark.timeout(ADD_TIMEOUT)
def test_insert_binary_create_index(self, connect, binary_collection, get_binary_index):
'''
target: test build index insert after vector
method: insert vector and build index
expected: no error raised
'''
ids = connect.insert(binary_collection, default_binary_entities)
assert len(ids) == default_nb
connect.flush([binary_collection])
connect.create_index(binary_collection, binary_field_name, get_binary_index)
info = connect.get_collection_info(binary_collection)
fields = info["fields"]
for field in fields:
if field["field"] == binary_field_name:
assert field["indexes"][0] == get_binary_index
def test_insert_binary_search(self, connect, binary_collection):
'''
target: test search vector after insert vector after a while
method: insert vector, sleep, and search collection
expected: no error raised
'''
ids = connect.insert(binary_collection, default_binary_entities)
connect.flush([binary_collection])
query, vecs = gen_query_vectors(binary_field_name, default_binary_entities, default_top_k, 1, metric_type="JACCARD")
res = connect.search(binary_collection, query)
logging.getLogger().debug(res)
assert res
class TestInsertAsync:
@pytest.fixture(scope="function", autouse=True)
def skip_http_check(self, args):
@ -940,15 +1044,13 @@ class TestInsertInvalidBinary(object):
@pytest.mark.level(2)
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.level(2)
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_value)
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', get_field_int_value)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@ -972,13 +1074,6 @@ class TestInsertInvalidBinary(object):
with pytest.raises(Exception):
connect.insert(binary_id_collection, default_binary_entities, ids)
@pytest.mark.level(2)
def test_insert_with_invalid_field_name(self, connect, binary_collection, get_field_name):
field_name = get_field_name
tmp_entity = update_field_name(copy.deepcopy(default_binary_entity), "int64", get_field_name)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.level(2)
def test_insert_with_invalid_field_type(self, connect, binary_collection, get_field_type):
field_type = get_field_type
@ -986,13 +1081,6 @@ class TestInsertInvalidBinary(object):
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.level(2)
def test_insert_with_invalid_field_value(self, connect, binary_collection, get_field_int_value):
field_value = get_field_int_value
tmp_entity = update_field_type(copy.deepcopy(default_binary_entity), 'int64', field_value)
with pytest.raises(Exception):
connect.insert(binary_collection, tmp_entity)
@pytest.mark.level(2)
def test_insert_with_invalid_field_vector_value(self, connect, binary_collection, get_field_vectors_value):
tmp_entity = copy.deepcopy(default_binary_entities)