milvus/tests/python_client/scale/scale_common.py
binbin eff75c7701
Replace sdk source and merge tests and tests20 (#7182)
Signed-off-by: Binbin Lv <binbin.lv@zilliz.com>
2021-08-20 11:00:56 +08:00

58 lines
2.1 KiB
Python

import os
from pymilvus import connections, Index
from scale import constants
from utils.util_log import test_log as log
from base.collection_wrapper import ApiCollectionWrapper
from common import common_func as cf
from common import common_type as ct
def get_milvus_chart_env_var(var=constants.MILVUS_CHART_ENV):
    """Return the path of the Milvus helm chart.

    The path is looked up in the environment variable named by ``var``. When
    that variable is not set, ``constants.MILVUS_CHART_PATH`` is used instead
    and a warning is logged.

    Args:
        var: Name of the environment variable that holds the chart path.

    Returns:
        The chart path. A value taken from the environment is converted to
        ``str`` and returned as-is, without checking that it exists; the
        fallback value is returned only after its existence was verified.

    Raises:
        Exception: If the fallback path ``constants.MILVUS_CHART_PATH`` does
            not exist on disk.
    """
    try:
        # Only the lookup can fail; an unset variable surfaces as KeyError.
        return str(os.environ[var])
    except KeyError as e:
        milvus_helm_chart = constants.MILVUS_CHART_PATH
        log.warning(
            f"Failed to get environment variables: {var}, use default: {constants.MILVUS_CHART_PATH}, {str(e)}")

    # Only the default path is validated here.
    if not os.path.exists(milvus_helm_chart):
        raise Exception(f'milvus_helm_chart: {milvus_helm_chart} not exist')
    return milvus_helm_chart
def e2e_milvus(host, c_name, port=19530):
    """Run an end-to-end smoke check against a Milvus instance.

    The check registers and opens the ``default`` connection, initializes a
    collection with the default schema, inserts ``ct.default_nb`` generated
    rows, builds the default index on the float-vector field, loads the
    collection, then runs one search and one query, asserting on the result
    of each step.

    Args:
        host: Host of the Milvus service to connect to.
        c_name: Name of the collection to initialize and exercise.
        port: Port of the Milvus service. Defaults to 19530, the value that
            was previously hard-coded.

    Raises:
        AssertionError: If any of the insert/index/search/query checks fails.
    """
    # connect
    connections.add_connection(default={"host": host, "port": port})
    connections.connect(alias='default')

    # create
    collection_w = ApiCollectionWrapper()
    collection_w.init_collection(name=c_name, schema=cf.gen_default_collection_schema())

    # insert
    data = cf.gen_default_list_data(ct.default_nb)
    mutation_res, _ = collection_w.insert(data)
    assert mutation_res.insert_count == ct.default_nb

    # create index
    collection_w.create_index(ct.default_float_vec_field_name, ct.default_index)
    assert collection_w.has_index()[0]
    expected_index = Index(collection_w.collection, ct.default_float_vec_field_name,
                           ct.default_index)
    assert collection_w.index()[0] == expected_index

    # search
    collection_w.load()
    # NOTE(review): data[-1] is presumably the float-vector column of the
    # generated data - confirm against cf.gen_default_list_data.
    search_vectors = data[-1][:ct.default_nq]
    search_res, _ = collection_w.search(search_vectors, ct.default_float_vec_field_name,
                                        ct.default_search_params, ct.default_limit)
    assert len(search_res[0]) == ct.default_limit

    # query: look up the first id from the first search result by primary key
    first_id = search_res[0].ids[0]
    term_expr = f'{ct.default_int64_field_name} in [{first_id}]'
    query_res, _ = collection_w.query(term_expr, output_fields=["*", "%"])
    assert query_res[0][ct.default_int64_field_name] == first_id