milvus/tests/milvus_python_test/test_search_vectors.py
BossZou f10f6cd5f4
Support run dev test with http handler in python SDK (#1116)
* refactoring(create_table done)

* refactoring

* refactor server delivery (insert done)

* refactoring server module (count_table done)

* server refactor done

* cmake pass

* refactor server module done.

* set grpc response status correctly

* format done.

* fix redefine ErrorMap()

* optimize insert reducing ids data copy

* optimize grpc request with reducing data copy

* clang format

* [skip ci] Refactor server module done. update changlog. prepare for PR

* remove explicit and change int32_t to int64_t

* add web server

* [skip ci] add license in web module

* modify header include & comment oatpp environment config

* add port configure & create table in handler

* modify web url

* simple url complation done & add swagger

* make sure web url

* web functionality done. debuging

* add web unittest

* web test pass

* add web server port

* add web server port in template

* update unittest cmake file

* change web server default port to 19121

* rename method in web module & unittest pass

* add search case in unittest for web module

* rename some variables

* fix bug

* unittest pass

* web prepare

* fix cmd bug(check server status)

* update changlog

* add web port validate & default set

* clang-format pass

* add web port test in unittest

* add CORS & redirect root to swagger ui

* add web status

* web table method func cascade test pass

* add config url in web module

* modify thirdparty cmake to avoid building oatpp test

* clang format

* update changlog

* add constants in web module

* reserve Config.cpp

* fix constants reference bug

* replace web server with async module

* modify component to support async

* format

* developing controller & add test clent into unittest

* add web port into demo/server_config

* modify thirdparty cmake to allow build test

* remove  unnecessary comment

* add endpoint info in controller

* finish web test(bug here)

* clang format

* add web test cpp to lint exclusions

* check null field in GetConfig

* add macro RETURN STATUS DTo

* fix cmake conflict

* fix crash when exit server

* remove surplus comments & add http param check

* add uri /docs to direct swagger

* format

* change cmd to system

* add default value & unittest in web module

* add macros to judge if GPU supported

* add macros in unit & add default in index dto & print error message when bind http port fail

* format (fix #788)

* fix cors bug (not completed)

* comment cors

* change web framework to simple api

* comments optimize

* change to simple API

* remove comments in controller.hpp

* remove EP_COMMON_CMAKE_ARGS in oatpp and oatpp-swagger

* add ep cmake args to sqlite

* clang-format

* change a format

* test pass

* change name to

* fix compiler issue(oatpp-swagger depend on oatpp)

* add & in start_server.h

* specify lib location with oatpp and oatpp-swagger

* add comments

* add swagger definition

* [skip ci] change http method options status code

* remove oatpp swagger(fix #970)

* remove comments

* check Start web behavior

* add default to cpu_cache_capacity

* remove swagger component.hpp & /docs url

* remove /docs info

* remove /docs in unittest

* remove space in test rpc

* remove repeate info in CHANGLOG

* change cache_insert_data default value as a constant

* [skip ci] Fix some broken links (#960)

* [skip ci] Fix broken link

* [skip ci] Fix broken link

* [skip ci] Fix broken link

* [skip ci] Fix broken links

* fix issue 373 (#964)

* fix issue 373

* Adjustment format

* Adjustment format

* Adjustment format

* change readme

* #966 update NOTICE.md (#967)

* remove comments

* check Start web behavior

* add default to cpu_cache_capacity

* remove swagger component.hpp & /docs url

* remove /docs info

* remove /docs in unittest

* remove space in test rpc

* remove repeate info in CHANGLOG

* change cache_insert_data default value as a constant

* adjust web port cofig place

* rename web_port variable

* change gpu resources invoke way to cmd()

* set advanced config name add DEFAULT

* change config setting to cmd

* modify ..

* optimize code

* assign TableDto' count default value 0 (fix #995)

* check if table exists when show partitions (fix #1028)

* check table exists when drop partition (fix #1029)

* check if partition name is legal (fix #1022)

* modify status code when partition tag is illegal

* update changlog

* add info to /system url

* add binary index and add bin uri & handler method(not completed)

* optimize http insert and search time(fix #1066) | add binary vectors support(fix #1067)

* fix test partition bug

* fix test bug when check insert records

* add binary vectors test

* add default for offset and page_size

* fix uinttest bug

* [skip ci] remove comments

* optimize web code for PR comments

* add new folder named utils

* check offset and pagesize (fix #1082)

* improve error message if offset or page_size is not legal (fix #1075)

* add log into web module

* update changlog

* check gpu sources setting when assign repeated value (fix #990)

* update changlog

* clang-format pass

* add default handler in http handler

* [skip ci] improve error msg when check gpu resources

* change check offset way

* remove func IsIntStr

* add case

* change int32 to int64 when check number str

* add log in we module(doing)

* update test case

* add log in web controller

* remove surplus dot

* add preload into /system/

* change get_milvus() to get_milvus(args['handler'])

* support load table into memory with http server (fix #1115)

* [skip ci] comment surplus dto in VectorDto

Co-authored-by: jielinxu <52057195+jielinxu@users.noreply.github.com>
Co-authored-by: JackLCL <53512883+JackLCL@users.noreply.github.com>
Co-authored-by: Cai Yudong <yudong.cai@zilliz.com>
2020-02-14 11:07:09 +08:00

1042 lines
42 KiB
Python

import pdb
import copy
import struct
import pytest
import threading
import datetime
import logging
from time import sleep
from multiprocessing import Process
import numpy
from milvus import IndexType, MetricType
from utils import *
dim = 128
table_id = "test_search"
add_interval_time = 2
vectors = gen_vectors(6000, dim)
# vectors /= numpy.linalg.norm(vectors)
# vectors = vectors.tolist()
nprobe = 1
epsilon = 0.001
tag = "1970-01-01"
raw_vectors, binary_vectors = gen_binary_vectors(6000, dim)
class TestSearchBase:
def init_data(self, connect, table, nb=6000):
'''
Generate vectors and add it in table, before search vectors
'''
global vectors
if nb == 6000:
add_vectors = vectors
else:
add_vectors = gen_vectors(nb, dim)
# add_vectors /= numpy.linalg.norm(add_vectors)
# add_vectors = add_vectors.tolist()
status, ids = connect.add_vectors(table, add_vectors)
sleep(add_interval_time)
return add_vectors, ids
def init_binary_data(self, connect, table, nb=6000, insert=True):
'''
Generate vectors and add it in table, before search vectors
'''
ids = []
global binary_vectors
global raw_vectors
if nb == 6000:
add_vectors = binary_vectors
add_raw_vectors = raw_vectors
else:
add_raw_vectors, add_vectors = gen_binary_vectors(nb, dim)
# add_vectors /= numpy.linalg.norm(add_vectors)
# add_vectors = add_vectors.tolist()
if insert is True:
status, ids = connect.add_vectors(table, add_vectors)
sleep(add_interval_time)
return add_raw_vectors, add_vectors, ids
"""
generate valid create_index params
"""
@pytest.fixture(
scope="function",
params=gen_index_params()
)
def get_index_params(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
pytest.skip("sq8h not support in open source")
if request.param["index_type"] == IndexType.IVF_PQ:
pytest.skip("Skip PQ Temporary")
return request.param
@pytest.fixture(
scope="function",
params=gen_simple_index_params()
)
def get_simple_index_params(self, request, connect):
if str(connect._cmd("mode")[1]) == "CPU":
if request.param["index_type"] == IndexType.IVF_SQ8H:
pytest.skip("sq8h not support in open source")
if request.param["index_type"] == IndexType.IVF_PQ:
pytest.skip("Skip PQ Temporary")
return request.param
@pytest.fixture(
scope="function",
params=gen_simple_index_params()
)
def get_jaccard_index_params(self, request, connect):
logging.getLogger().info(request.param)
if request.param["index_type"] == IndexType.IVFLAT or request.param["index_type"] == IndexType.FLAT:
return request.param
else:
pytest.skip("Skip index Temporary")
@pytest.fixture(
scope="function",
params=gen_simple_index_params()
)
def get_hamming_index_params(self, request, connect):
logging.getLogger().info(request.param)
if request.param["index_type"] == IndexType.IVFLAT or request.param["index_type"] == IndexType.FLAT:
return request.param
else:
pytest.skip("Skip index Temporary")
"""
generate top-k params
"""
@pytest.fixture(
scope="function",
params=[1, 99, 1024, 2048, 2049]
)
def get_top_k(self, request):
yield request.param
def test_search_top_k_flat_index(self, connect, table, get_top_k):
'''
target: test basic search fuction, all the search params is corrent, change top-k value
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
'''
vectors, ids = self.init_data(connect, table)
query_vec = [vectors[0]]
top_k = get_top_k
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
if top_k <= 2048:
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert result[0][0].distance <= epsilon
assert check_result(result[0], ids[0])
else:
assert not status.OK()
def test_search_l2_index_params(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
vectors, ids = self.init_data(connect, table)
status = connect.create_index(table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
logging.getLogger().info(result)
if top_k <= 1024:
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert result[0][0].distance <= epsilon
else:
assert not status.OK()
def test_search_l2_index_params_partition(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: add vectors into table, search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k, search table with partition tag return empty
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
vectors, ids = self.init_data(connect, table)
status = connect.create_index(table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert result[0][0].distance <= epsilon
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=[tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result) == 0
def test_search_l2_index_params_partition_A(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search partition with the given vectors, check the result
expected: search status ok, and the length of the result is 0
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
vectors, ids = self.init_data(connect, table)
status = connect.create_index(table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(partition_name, top_k, nprobe, query_vec, partition_tags=[tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result) == 0
def test_search_l2_index_params_partition_B(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
vectors, ids = self.init_data(connect, partition_name)
status = connect.create_index(table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec)
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert result[0][0].distance <= epsilon
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=[tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert result[0][0].distance <= epsilon
status, result = connect.search_vectors(partition_name, top_k, nprobe, query_vec, partition_tags=[tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result) == 0
def test_search_l2_index_params_partition_C(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors and tags (one of the tags not existed in table), check the result
expected: search status ok, and the length of the result is top_k
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
vectors, ids = self.init_data(connect, partition_name)
status = connect.create_index(table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=[tag, "new_tag"])
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert result[0][0].distance <= epsilon
def test_search_l2_index_params_partition_D(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors and tag (tag name not existed in table), check the result
expected: search status ok, and the length of the result is top_k
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
vectors, ids = self.init_data(connect, partition_name)
status = connect.create_index(table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=["new_tag"])
logging.getLogger().info(result)
assert status.OK()
assert len(result) == 0
def test_search_l2_index_params_partition_E(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search table with the given vectors and tags, check the result
expected: search status ok, and the length of the result is top_k
'''
new_tag = "new_tag"
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
new_partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
status = connect.create_partition(table, new_partition_name, new_tag)
vectors, ids = self.init_data(connect, partition_name)
new_vectors, new_ids = self.init_data(connect, new_partition_name, nb=6001)
status = connect.create_index(table, index_params)
query_vec = [vectors[0], new_vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=[tag, new_tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert check_result(result[1], new_ids[0])
assert result[0][0].distance <= epsilon
assert result[1][0].distance <= epsilon
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=[new_tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[1], new_ids[0])
assert result[1][0].distance <= epsilon
def test_search_l2_index_params_partition_F(self, connect, table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search table with the given vectors and tags with "re" expr, check the result
expected: search status ok, and the length of the result is top_k
'''
tag = "atag"
new_tag = "new_tag"
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
new_partition_name = gen_unique_str()
status = connect.create_partition(table, partition_name, tag)
status = connect.create_partition(table, new_partition_name, new_tag)
vectors, ids = self.init_data(connect, partition_name)
new_vectors, new_ids = self.init_data(connect, new_partition_name, nb=6001)
status = connect.create_index(table, index_params)
query_vec = [vectors[0], new_vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=["new(.*)"])
logging.getLogger().info(result)
assert status.OK()
assert result[0][0].distance > epsilon
assert result[1][0].distance <= epsilon
status, result = connect.search_vectors(table, top_k, nprobe, query_vec, partition_tags=["(.*)tag"])
logging.getLogger().info(result)
assert status.OK()
assert result[0][0].distance <= epsilon
assert result[1][0].distance <= epsilon
def test_search_ip_index_params(self, connect, ip_table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
vectors, ids = self.init_data(connect, ip_table)
status = connect.create_index(ip_table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
logging.getLogger().info(result)
if top_k <= 1024:
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance)
else:
assert not status.OK()
def test_search_ip_index_params_partition(self, connect, ip_table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors, check the result
expected: search status ok, and the length of the result is top_k
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
status = connect.create_partition(ip_table, partition_name, tag)
vectors, ids = self.init_data(connect, ip_table)
status = connect.create_index(ip_table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec)
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance)
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec, partition_tags=[tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result) == 0
def test_search_ip_index_params_partition_A(self, connect, ip_table, get_simple_index_params):
'''
target: test basic search fuction, all the search params is corrent, test all index params, and build
method: search with the given vectors and tag, check the result
expected: search status ok, and the length of the result is top_k
'''
index_params = get_simple_index_params
logging.getLogger().info(index_params)
partition_name = gen_unique_str()
status = connect.create_partition(ip_table, partition_name, tag)
vectors, ids = self.init_data(connect, partition_name)
status = connect.create_index(ip_table, index_params)
query_vec = [vectors[0]]
top_k = 10
nprobe = 1
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vec, partition_tags=[tag])
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
assert abs(result[0][0].distance - numpy.inner(numpy.array(query_vec[0]), numpy.array(query_vec[0]))) <= gen_inaccuracy(result[0][0].distance)
status, result = connect.search_vectors(partition_name, top_k, nprobe, query_vec)
logging.getLogger().info(result)
assert status.OK()
assert len(result[0]) == min(len(vectors), top_k)
assert check_result(result[0], ids[0])
@pytest.mark.level(2)
def test_search_vectors_without_connect(self, dis_connect, table):
'''
target: test search vectors without connection
method: use dis connected instance, call search method and check if search successfully
expected: raise exception
'''
query_vectors = [vectors[0]]
top_k = 1
nprobe = 1
with pytest.raises(Exception) as e:
status, ids = dis_connect.search_vectors(table, top_k, nprobe, query_vectors)
def test_search_table_name_not_existed(self, connect, table):
'''
target: search table not existed
method: search with the random table_name, which is not in db
expected: status not ok
'''
table_name = gen_unique_str("not_existed_table")
top_k = 1
nprobe = 1
query_vecs = [vectors[0]]
status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
assert not status.OK()
def test_search_table_name_None(self, connect, table):
'''
target: search table that table name is None
method: search with the table_name: None
expected: status not ok
'''
table_name = None
top_k = 1
nprobe = 1
query_vecs = [vectors[0]]
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
def test_search_top_k_query_records(self, connect, table):
'''
target: test search fuction, with search params: query_records
method: search with the given query_records, which are subarrays of the inserted vectors
expected: status ok and the returned vectors should be query_records
'''
top_k = 10
nprobe = 1
vectors, ids = self.init_data(connect, table)
query_vecs = [vectors[0],vectors[55],vectors[99]]
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
assert status.OK()
assert len(result) == len(query_vecs)
for i in range(len(query_vecs)):
assert len(result[i]) == top_k
assert result[i][0].distance <= epsilon
"""
generate invalid query range params
"""
@pytest.fixture(
scope="function",
params=[
(get_current_day(), get_current_day()),
(get_last_day(1), get_last_day(1)),
(get_next_day(1), get_next_day(1))
]
)
def get_invalid_range(self, request):
yield request.param
# disable
def _test_search_invalid_query_ranges(self, connect, table, get_invalid_range):
'''
target: search table with query ranges
method: search with the same query ranges
expected: status not ok
'''
top_k = 2
nprobe = 1
vectors, ids = self.init_data(connect, table)
query_vecs = [vectors[0]]
query_ranges = [get_invalid_range]
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
assert not status.OK()
assert len(result) == 0
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_last_day(1)),
(get_next_day(1), get_next_day(2))
]
)
def get_valid_range_no_result(self, request):
yield request.param
# disable
def _test_search_valid_query_ranges_no_result(self, connect, table, get_valid_range_no_result):
'''
target: search table with normal query ranges, but no data in db
method: search with query ranges (low, low)
expected: length of result is 0
'''
top_k = 2
nprobe = 1
vectors, ids = self.init_data(connect, table)
query_vecs = [vectors[0]]
query_ranges = [get_valid_range_no_result]
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
assert status.OK()
assert len(result) == 0
"""
generate valid query range params, no search result
"""
@pytest.fixture(
scope="function",
params=[
(get_last_day(2), get_next_day(2)),
(get_current_day(), get_next_day(2)),
]
)
def get_valid_range(self, request):
yield request.param
# disable
def _test_search_valid_query_ranges(self, connect, table, get_valid_range):
'''
target: search table with normal query ranges, but no data in db
method: search with query ranges (low, normal)
expected: length of result is 0
'''
top_k = 2
nprobe = 1
vectors, ids = self.init_data(connect, table)
query_vecs = [vectors[0]]
query_ranges = [get_valid_range]
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs, query_ranges=query_ranges)
assert status.OK()
assert len(result) == 1
assert result[0][0].distance <= epsilon
def test_search_distance_l2_flat_index(self, connect, table):
'''
target: search table, and check the result: distance
method: compare the return distance value with value computed with Euclidean
expected: the return distance equals to the computed value
'''
nb = 2
top_k = 1
nprobe = 1
vectors, ids = self.init_data(connect, table, nb=nb)
query_vecs = [[0.50 for i in range(dim)]]
distance_0 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[0]))
distance_1 = numpy.linalg.norm(numpy.array(query_vecs[0]) - numpy.array(vectors[1]))
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
assert abs(numpy.sqrt(result[0][0].distance) - min(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
def test_search_distance_ip_flat_index(self, connect, ip_table):
'''
target: search ip_table, and check the result: distance
method: compare the return distance value with value computed with Inner product
expected: the return distance equals to the computed value
'''
nb = 2
top_k = 1
nprobe = 1
vectors, ids = self.init_data(connect, ip_table, nb=nb)
index_params = {
"index_type": IndexType.FLAT,
"nlist": 16384
}
connect.create_index(ip_table, index_params)
logging.getLogger().info(connect.describe_index(ip_table))
query_vecs = [[0.50 for i in range(dim)]]
distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
def test_search_distance_jaccard_flat_index(self, connect, jac_table):
'''
target: search ip_table, and check the result: distance
method: compare the return distance value with value computed with Inner product
expected: the return distance equals to the computed value
'''
# from scipy.spatial import distance
top_k = 1
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, jac_table, nb=2)
index_params = {
"index_type": IndexType.FLAT,
"nlist": 16384
}
connect.create_index(jac_table, index_params)
logging.getLogger().info(connect.describe_table(jac_table))
logging.getLogger().info(connect.describe_index(jac_table))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, jac_table, nb=1, insert=False)
distance_0 = jaccard(query_int_vectors[0], int_vectors[0])
distance_1 = jaccard(query_int_vectors[0], int_vectors[1])
status, result = connect.search_vectors(jac_table, top_k, nprobe, query_vecs)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon
def test_search_distance_hamming_flat_index(self, connect, ham_table):
'''
target: search ip_table, and check the result: distance
method: compare the return distance value with value computed with Inner product
expected: the return distance equals to the computed value
'''
# from scipy.spatial import distance
top_k = 1
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, ham_table, nb=2)
index_params = {
"index_type": IndexType.FLAT,
"nlist": 16384
}
connect.create_index(ham_table, index_params)
logging.getLogger().info(connect.describe_table(ham_table))
logging.getLogger().info(connect.describe_index(ham_table))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, ham_table, nb=1, insert=False)
distance_0 = hamming(query_int_vectors[0], int_vectors[0])
distance_1 = hamming(query_int_vectors[0], int_vectors[1])
status, result = connect.search_vectors(ham_table, top_k, nprobe, query_vecs)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert abs(result[0][0].distance - min(distance_0, distance_1).astype(float)) <= epsilon
def test_search_distance_tanimoto_flat_index(self, connect, tanimoto_table):
'''
target: search ip_table, and check the result: distance
method: compare the return distance value with value computed with Inner product
expected: the return distance equals to the computed value
'''
# from scipy.spatial import distance
top_k = 1
nprobe = 512
int_vectors, vectors, ids = self.init_binary_data(connect, tanimoto_table, nb=2)
index_params = {
"index_type": IndexType.FLAT,
"nlist": 16384
}
connect.create_index(tanimoto_table, index_params)
logging.getLogger().info(connect.describe_table(tanimoto_table))
logging.getLogger().info(connect.describe_index(tanimoto_table))
query_int_vectors, query_vecs, tmp_ids = self.init_binary_data(connect, tanimoto_table, nb=1, insert=False)
distance_0 = tanimoto(query_int_vectors[0], int_vectors[0])
distance_1 = tanimoto(query_int_vectors[0], int_vectors[1])
status, result = connect.search_vectors(tanimoto_table, top_k, nprobe, query_vecs)
logging.getLogger().info(status)
logging.getLogger().info(result)
assert abs(result[0][0].distance - min(distance_0, distance_1)) <= epsilon
def test_search_distance_ip_index_params(self, connect, ip_table, get_index_params):
'''
target: search table, and check the result: distance
method: compare the return distance value with value computed with Inner product
expected: the return distance equals to the computed value
'''
top_k = 2
nprobe = 1
vectors, ids = self.init_data(connect, ip_table, nb=2)
index_params = get_index_params
connect.create_index(ip_table, index_params)
logging.getLogger().info(connect.describe_index(ip_table))
query_vecs = [[0.50 for i in range(dim)]]
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
logging.getLogger().debug(status)
logging.getLogger().debug(result)
distance_0 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[0]))
distance_1 = numpy.inner(numpy.array(query_vecs[0]), numpy.array(vectors[1]))
assert abs(result[0][0].distance - max(distance_0, distance_1)) <= gen_inaccuracy(result[0][0].distance)
# TODO: enable
# @pytest.mark.repeat(5)
@pytest.mark.timeout(30)
def _test_search_concurrent(self, connect, table):
vectors, ids = self.init_data(connect, table)
thread_num = 10
nb = 100
top_k = 10
threads = []
query_vecs = vectors[nb//2:nb]
def search():
status, result = connect.search_vectors(table, top_k, query_vecs)
assert len(result) == len(query_vecs)
for i in range(len(query_vecs)):
assert result[i][0].id in ids
assert result[i][0].distance == 0.0
for i in range(thread_num):
x = threading.Thread(target=search, args=())
threads.append(x)
x.start()
for th in threads:
th.join()
# TODO: enable
@pytest.mark.timeout(30)
def _test_search_concurrent_multiprocessing(self, args):
'''
target: test concurrent search with multiprocessess
method: search with 10 processes, each process uses dependent connection
expected: status ok and the returned vectors should be query_records
'''
nb = 100
top_k = 10
process_num = 4
processes = []
table = gen_unique_str("test_search_concurrent_multiprocessing")
uri = "tcp://%s:%s" % (args["ip"], args["port"])
param = {'table_name': table,
'dimension': dim,
'index_type': IndexType.FLAT,
'store_raw_vector': False}
# create table
milvus = get_milvus(args["handler"])
milvus.connect(uri=uri)
milvus.create_table(param)
vectors, ids = self.init_data(milvus, table, nb=nb)
query_vecs = vectors[nb//2:nb]
def search(milvus):
status, result = milvus.search_vectors(table, top_k, query_vecs)
assert len(result) == len(query_vecs)
for i in range(len(query_vecs)):
assert result[i][0].id in ids
assert result[i][0].distance == 0.0
for i in range(process_num):
milvus = get_milvus(args["handler"])
milvus.connect(uri=uri)
p = Process(target=search, args=(milvus, ))
processes.append(p)
p.start()
time.sleep(0.2)
for p in processes:
p.join()
def test_search_multi_table_L2(search, args):
'''
target: test search multi tables of L2
method: add vectors into 10 tables, and search
expected: search status ok, the length of result
'''
num = 10
top_k = 10
nprobe = 1
tables = []
idx = []
for i in range(num):
table = gen_unique_str("test_add_multitable_%d" % i)
uri = "tcp://%s:%s" % (args["ip"], args["port"])
param = {'table_name': table,
'dimension': dim,
'index_file_size': 10,
'metric_type': MetricType.L2}
# create table
milvus = get_milvus(args["handler"])
milvus.connect(uri=uri)
milvus.create_table(param)
status, ids = milvus.add_vectors(table, vectors)
assert status.OK()
assert len(ids) == len(vectors)
tables.append(table)
idx.append(ids[0])
idx.append(ids[10])
idx.append(ids[20])
time.sleep(6)
query_vecs = [vectors[0], vectors[10], vectors[20]]
# start query from random table
for i in range(num):
table = tables[i]
status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
assert status.OK()
assert len(result) == len(query_vecs)
for j in range(len(query_vecs)):
assert len(result[j]) == top_k
for j in range(len(query_vecs)):
assert check_result(result[j], idx[3 * i + j])
def test_search_multi_table_IP(search, args):
'''
target: test search multi tables of IP
method: add vectors into 10 tables, and search
expected: search status ok, the length of result
'''
num = 10
top_k = 10
nprobe = 1
tables = []
idx = []
for i in range(num):
table = gen_unique_str("test_add_multitable_%d" % i)
uri = "tcp://%s:%s" % (args["ip"], args["port"])
param = {'table_name': table,
'dimension': dim,
'index_file_size': 10,
'metric_type': MetricType.L2}
# create table
milvus = get_milvus(args["handler"])
milvus.connect(uri=uri)
milvus.create_table(param)
status, ids = milvus.add_vectors(table, vectors)
assert status.OK()
assert len(ids) == len(vectors)
tables.append(table)
idx.append(ids[0])
idx.append(ids[10])
idx.append(ids[20])
time.sleep(6)
query_vecs = [vectors[0], vectors[10], vectors[20]]
# start query from random table
for i in range(num):
table = tables[i]
status, result = milvus.search_vectors(table, top_k, nprobe, query_vecs)
assert status.OK()
assert len(result) == len(query_vecs)
for j in range(len(query_vecs)):
assert len(result[j]) == top_k
for j in range(len(query_vecs)):
assert check_result(result[j], idx[3 * i + j])
"""
******************************************************************
# The following cases are used to test `search_vectors` function
# with invalid table_name top-k / nprobe / query_range
******************************************************************
"""
class TestSearchParamsInvalid(object):
nlist = 16384
index_param = {"index_type": IndexType.IVF_SQ8, "nlist": nlist}
logging.getLogger().info(index_param)
def init_data(self, connect, table, nb=6000):
'''
Generate vectors and add it in table, before search vectors
'''
global vectors
if nb == 6000:
add_vectors = vectors
else:
add_vectors = gen_vectors(nb, dim)
status, ids = connect.add_vectors(table, add_vectors)
sleep(add_interval_time)
return add_vectors, ids
"""
Test search table with invalid table names
"""
@pytest.fixture(
scope="function",
params=gen_invalid_table_names()
)
def get_table_name(self, request):
yield request.param
@pytest.mark.level(2)
def test_search_with_invalid_tablename(self, connect, get_table_name):
table_name = get_table_name
logging.getLogger().info(table_name)
top_k = 1
nprobe = 1
query_vecs = gen_vectors(1, dim)
status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs)
assert not status.OK()
@pytest.mark.level(1)
def test_search_with_invalid_tag_format(self, connect, table):
top_k = 1
nprobe = 1
query_vecs = gen_vectors(1, dim)
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table_name, top_k, nprobe, query_vecs, partition_tags="tag")
"""
Test search table with invalid top-k
"""
@pytest.fixture(
scope="function",
params=gen_invalid_top_ks()
)
def get_top_k(self, request):
yield request.param
@pytest.mark.level(1)
def test_search_with_invalid_top_k(self, connect, table, get_top_k):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = get_top_k
logging.getLogger().info(top_k)
nprobe = 1
query_vecs = gen_vectors(1, dim)
if isinstance(top_k, int):
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
assert not status.OK()
else:
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
@pytest.mark.level(2)
def test_search_with_invalid_top_k_ip(self, connect, ip_table, get_top_k):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = get_top_k
logging.getLogger().info(top_k)
nprobe = 1
query_vecs = gen_vectors(1, dim)
if isinstance(top_k, int):
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
assert not status.OK()
else:
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
"""
Test search table with invalid nprobe
"""
@pytest.fixture(
scope="function",
params=gen_invalid_nprobes()
)
def get_nprobes(self, request):
yield request.param
@pytest.mark.level(1)
def test_search_with_invalid_nprobe(self, connect, table, get_nprobes):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = get_nprobes
logging.getLogger().info(nprobe)
query_vecs = gen_vectors(1, dim)
if isinstance(nprobe, int):
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
assert not status.OK()
else:
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table, top_k, nprobe, query_vecs)
@pytest.mark.level(2)
def test_search_with_invalid_nprobe_ip(self, connect, ip_table, get_nprobes):
'''
target: test search fuction, with the wrong top_k
method: search with top_k
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = get_nprobes
logging.getLogger().info(nprobe)
query_vecs = gen_vectors(1, dim)
if isinstance(nprobe, int):
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
assert not status.OK()
else:
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(ip_table, top_k, nprobe, query_vecs)
"""
Test search table with invalid query ranges
"""
@pytest.fixture(
scope="function",
params=gen_invalid_query_ranges()
)
def get_query_ranges(self, request):
yield request.param
# disable
@pytest.mark.level(1)
def _test_search_flat_with_invalid_query_range(self, connect, table, get_query_ranges):
'''
target: test search fuction, with the wrong query_range
method: search with query_range
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = 1
query_vecs = [vectors[0]]
query_ranges = get_query_ranges
logging.getLogger().info(query_ranges)
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(table, 1, nprobe, query_vecs, query_ranges=query_ranges)
# disable
@pytest.mark.level(2)
def _test_search_flat_with_invalid_query_range_ip(self, connect, ip_table, get_query_ranges):
'''
target: test search fuction, with the wrong query_range
method: search with query_range
expected: raise an error, and the connection is normal
'''
top_k = 1
nprobe = 1
query_vecs = [vectors[0]]
query_ranges = get_query_ranges
logging.getLogger().info(query_ranges)
with pytest.raises(Exception) as e:
status, result = connect.search_vectors(ip_table, 1, nprobe, query_vecs, query_ranges=query_ranges)
def check_result(result, id):
if len(result) >= 5:
return id in [result[0].id, result[1].id, result[2].id, result[3].id, result[4].id]
else:
return id in (i.id for i in result)