mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-02-02 01:06:41 +08:00
[test]Add restful api test (#25583)
Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
parent
8560af6790
commit
745a550465
0
tests/restful_client/README.md
Normal file
0
tests/restful_client/README.md
Normal file
226
tests/restful_client/api/milvus.py
Normal file
226
tests/restful_client/api/milvus.py
Normal file
@ -0,0 +1,226 @@
|
||||
import json
|
||||
import requests
|
||||
import time
|
||||
import uuid
|
||||
from utils.util_log import test_log as logger
|
||||
|
||||
|
||||
def logger_request_response(response, url, tt, headers, data, str_data, str_response, method):
|
||||
if len(data) > 2000:
|
||||
data = data[:1000] + "..." + data[-1000:]
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
if ('code' in response.json() and response.json()["code"] == 200) or ('Code' in response.json() and response.json()["Code"] == 0):
|
||||
logger.debug(
|
||||
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {str_data}, response: {str_response}")
|
||||
else:
|
||||
logger.error(
|
||||
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
|
||||
else:
|
||||
logger.error(
|
||||
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(
|
||||
f"method: {method}, url: {url}, cost time: {tt}, header: {headers}, payload: {data}, response: {response.text}")
|
||||
|
||||
|
||||
class Requests:
|
||||
def __init__(self, url=None, api_key=None):
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
self.headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'RequestId': str(uuid.uuid1())
|
||||
}
|
||||
|
||||
def update_headers(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'RequestId': str(uuid.uuid1())
|
||||
}
|
||||
return headers
|
||||
|
||||
def post(self, url, headers=None, data=None):
|
||||
headers = headers if headers is not None else self.update_headers()
|
||||
data = json.dumps(data)
|
||||
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||
t0 = time.time()
|
||||
response = requests.post(url, headers=headers, data=data)
|
||||
tt = time.time() - t0
|
||||
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||
logger_request_response(response, url, tt, headers, data, str_data, str_response, "post")
|
||||
return response
|
||||
|
||||
def get(self, url, headers=None, params=None, data=None):
|
||||
headers = headers if headers is not None else self.update_headers()
|
||||
data = json.dumps(data)
|
||||
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||
t0 = time.time()
|
||||
if data is None or data == "null":
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
else:
|
||||
response = requests.get(url, headers=headers, params=params, data=data)
|
||||
tt = time.time() - t0
|
||||
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||
logger_request_response(response, url, tt, headers, data, str_data, str_response, "get")
|
||||
return response
|
||||
|
||||
def put(self, url, headers=None, data=None):
|
||||
headers = headers if headers is not None else self.update_headers()
|
||||
data = json.dumps(data)
|
||||
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||
t0 = time.time()
|
||||
response = requests.put(url, headers=headers, data=data)
|
||||
tt = time.time() - t0
|
||||
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||
logger_request_response(response, url, tt, headers, data, str_data, str_response, "put")
|
||||
return response
|
||||
|
||||
def delete(self, url, headers=None, data=None):
|
||||
headers = headers if headers is not None else self.update_headers()
|
||||
data = json.dumps(data)
|
||||
str_data = data[:200] + '...' + data[-200:] if len(data) > 400 else data
|
||||
t0 = time.time()
|
||||
response = requests.delete(url, headers=headers, data=data)
|
||||
tt = time.time() - t0
|
||||
str_response = response.text[:200] + '...' + response.text[-200:] if len(response.text) > 400 else response.text
|
||||
logger_request_response(response, url, tt, headers, data, str_data, str_response, "delete")
|
||||
return response
|
||||
|
||||
|
||||
class VectorClient(Requests):
|
||||
def __init__(self, url, api_key, protocol="http"):
|
||||
super().__init__(url, api_key)
|
||||
self.protocol = protocol
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
self.db_name = None
|
||||
self.headers = self.update_headers()
|
||||
|
||||
def update_headers(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'RequestId': str(uuid.uuid1())
|
||||
}
|
||||
return headers
|
||||
|
||||
def vector_search(self, payload, db_name="default"):
|
||||
url = f'{self.protocol}://{self.url}/vector/search'
|
||||
if self.db_name is not None:
|
||||
payload["dbName"] = self.db_name
|
||||
if db_name != "default":
|
||||
payload["dbName"] = db_name
|
||||
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||
return response.json()
|
||||
|
||||
def vector_query(self, payload, db_name="default"):
|
||||
url = f'{self.protocol}://{self.url}/vector/query'
|
||||
if self.db_name is not None:
|
||||
payload["dbName"] = self.db_name
|
||||
if db_name != "default":
|
||||
payload["dbName"] = db_name
|
||||
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||
return response.json()
|
||||
|
||||
def vector_get(self, payload, db_name="default"):
|
||||
url = f'{self.protocol}://{self.url}/vector/get'
|
||||
if self.db_name is not None:
|
||||
payload["dbName"] = self.db_name
|
||||
if db_name != "default":
|
||||
payload["dbName"] = db_name
|
||||
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||
return response.json()
|
||||
|
||||
def vector_delete(self, payload, db_name="default"):
|
||||
url = f'{self.protocol}://{self.url}/vector/delete'
|
||||
if self.db_name is not None:
|
||||
payload["dbName"] = self.db_name
|
||||
if db_name != "default":
|
||||
payload["dbName"] = db_name
|
||||
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||
return response.json()
|
||||
|
||||
def vector_insert(self, payload, db_name="default"):
|
||||
url = f'{self.protocol}://{self.url}/vector/insert'
|
||||
if self.db_name is not None:
|
||||
payload["dbName"] = self.db_name
|
||||
if db_name != "default":
|
||||
payload["dbName"] = db_name
|
||||
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||
return response.json()
|
||||
|
||||
|
||||
class CollectionClient(Requests):
|
||||
|
||||
def __init__(self, url, api_key, protocol="http"):
|
||||
super().__init__(url, api_key)
|
||||
self.protocol = protocol
|
||||
self.url = url
|
||||
self.api_key = api_key
|
||||
self.db_name = None
|
||||
self.headers = self.update_headers()
|
||||
|
||||
def update_headers(self):
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'RequestId': str(uuid.uuid1())
|
||||
}
|
||||
return headers
|
||||
|
||||
def collection_list(self, db_name="default"):
|
||||
url = f'{self.protocol}://{self.url}/vector/collections'
|
||||
params = {}
|
||||
if self.db_name is not None:
|
||||
params = {
|
||||
"dbName": self.db_name
|
||||
}
|
||||
if db_name != "default":
|
||||
params = {
|
||||
"dbName": db_name
|
||||
}
|
||||
response = self.get(url, headers=self.update_headers(), params=params)
|
||||
res = response.json()
|
||||
if res["data"] is None:
|
||||
res["data"] = []
|
||||
return res
|
||||
|
||||
def collection_create(self, payload, db_name="default"):
|
||||
time.sleep(1) # wait for collection created and in case of rate limit
|
||||
url = f'{self.protocol}://{self.url}/vector/collections/create'
|
||||
if self.db_name is not None:
|
||||
payload["dbName"] = self.db_name
|
||||
if db_name != "default":
|
||||
payload["dbName"] = db_name
|
||||
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||
return response.json()
|
||||
|
||||
def collection_describe(self, collection_name, db_name="default"):
|
||||
url = f'{self.protocol}://{self.url}/vector/collections/describe'
|
||||
params = {"collectionName": collection_name}
|
||||
if self.db_name is not None:
|
||||
params = {
|
||||
"collectionName": collection_name,
|
||||
"dbName": self.db_name
|
||||
}
|
||||
if db_name != "default":
|
||||
params = {
|
||||
"collectionName": collection_name,
|
||||
"dbName": db_name
|
||||
}
|
||||
response = self.get(url, headers=self.update_headers(), params=params)
|
||||
return response.json()
|
||||
|
||||
def collection_drop(self, payload, db_name="default"):
|
||||
time.sleep(1) # wait for collection drop and in case of rate limit
|
||||
url = f'{self.protocol}://{self.url}/vector/collections/drop'
|
||||
if self.db_name is not None:
|
||||
payload["dbName"] = self.db_name
|
||||
if db_name != "default":
|
||||
payload["dbName"] = db_name
|
||||
response = self.post(url, headers=self.update_headers(), data=payload)
|
||||
return response.json()
|
||||
41
tests/restful_client/base/error_code_message.py
Normal file
41
tests/restful_client/base/error_code_message.py
Normal file
@ -0,0 +1,41 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BaseError(Enum):
|
||||
pass
|
||||
|
||||
|
||||
class VectorInsertError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class VectorSearchError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class VectorGetError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class VectorQueryError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class VectorDeleteError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class CollectionListError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class CollectionCreateError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class CollectionDropError(BaseError):
|
||||
pass
|
||||
|
||||
|
||||
class CollectionDescribeError(BaseError):
|
||||
pass
|
||||
113
tests/restful_client/base/testbase.py
Normal file
113
tests/restful_client/base/testbase.py
Normal file
@ -0,0 +1,113 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from pymilvus import connections, db
|
||||
from utils.util_log import test_log as logger
|
||||
from api.milvus import VectorClient, CollectionClient
|
||||
from utils.utils import get_data_by_payload
|
||||
|
||||
|
||||
def get_config():
|
||||
pass
|
||||
|
||||
|
||||
class Base:
|
||||
host = None
|
||||
port = None
|
||||
url = None
|
||||
api_key = None
|
||||
username = None
|
||||
password = None
|
||||
invalid_api_key = None
|
||||
vector_client = None
|
||||
collection_client = None
|
||||
|
||||
|
||||
class TestBase(Base):
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def init_client(self, host, port, username, password):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.url = f"{host}:{port}/v1"
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.api_key = f"{self.username}:{self.password}"
|
||||
self.invalid_api_key = "invalid_token"
|
||||
self.vector_client = VectorClient(self.url, self.api_key)
|
||||
self.collection_client = CollectionClient(self.url, self.api_key)
|
||||
|
||||
def init_collection(self, collection_name, pk_field="id", metric_type="L2", dim=128, nb=100):
|
||||
# drop all collections
|
||||
try:
|
||||
all_collections = self.collection_client.collection_list()['data']
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
all_collections = []
|
||||
for collection in all_collections:
|
||||
name = collection
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
try:
|
||||
rsp = self.collection_client.collection_drop(payload)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
time.sleep(1)
|
||||
# create collection
|
||||
schema_payload = {
|
||||
"collectionName": collection_name,
|
||||
"dimension": dim,
|
||||
"metricType": metric_type,
|
||||
"description": "test collection",
|
||||
"primaryField": pk_field,
|
||||
"vectorField": "vector",
|
||||
}
|
||||
rsp = self.collection_client.collection_create(schema_payload)
|
||||
assert rsp['code'] == 200
|
||||
self.wait_collection_load_completed(collection_name)
|
||||
batch_size = 1000
|
||||
batch = nb // batch_size
|
||||
# in case of nb < batch_size
|
||||
if batch == 0:
|
||||
batch = 1
|
||||
batch_size = nb
|
||||
data = []
|
||||
for i in range(batch):
|
||||
nb = batch_size
|
||||
data = get_data_by_payload(schema_payload, nb)
|
||||
payload = {
|
||||
"collectionName": collection_name,
|
||||
"data": data
|
||||
}
|
||||
body_size = sys.getsizeof(json.dumps(payload))
|
||||
logger.info(f"body size: {body_size / 1024 / 1024} MB")
|
||||
rsp = self.vector_client.vector_insert(payload)
|
||||
assert rsp['code'] == 200
|
||||
return schema_payload, data
|
||||
|
||||
def wait_collection_load_completed(self, name):
|
||||
t0 = time.time()
|
||||
timeout = 60
|
||||
while True and time.time() - t0 < timeout:
|
||||
rsp = self.collection_client.collection_describe(name)
|
||||
if "data" in rsp and "load" in rsp["data"] and rsp["data"]["load"] == "LoadStateLoaded":
|
||||
break
|
||||
else:
|
||||
time.sleep(5)
|
||||
|
||||
def create_database(self, db_name="default"):
|
||||
connections.connect(host=self.host, port=self.port)
|
||||
all_db = db.list_database()
|
||||
logger.info(f"all database: {all_db}")
|
||||
if db_name not in all_db:
|
||||
logger.info(f"create database: {db_name}")
|
||||
db.create_database(db_name=db_name)
|
||||
|
||||
def update_database(self, db_name="default"):
|
||||
self.create_database(db_name=db_name)
|
||||
self.collection_client.db_name = db_name
|
||||
self.vector_client.db_name = db_name
|
||||
|
||||
44
tests/restful_client/config/log_config.py
Normal file
44
tests/restful_client/config/log_config.py
Normal file
@ -0,0 +1,44 @@
|
||||
import os
|
||||
|
||||
|
||||
class LogConfig:
|
||||
def __init__(self):
|
||||
self.log_debug = ""
|
||||
self.log_err = ""
|
||||
self.log_info = ""
|
||||
self.log_worker = ""
|
||||
self.get_default_config()
|
||||
|
||||
@staticmethod
|
||||
def get_env_variable(var="CI_LOG_PATH"):
|
||||
""" get log path for testing """
|
||||
try:
|
||||
log_path = os.environ[var]
|
||||
return str(log_path)
|
||||
except Exception as e:
|
||||
# now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
|
||||
log_path = f"/tmp/ci_logs"
|
||||
print("[get_env_variable] failed to get environment variables : %s, use default path : %s" % (str(e), log_path))
|
||||
return log_path
|
||||
|
||||
@staticmethod
|
||||
def create_path(log_path):
|
||||
if not os.path.isdir(str(log_path)):
|
||||
print("[create_path] folder(%s) is not exist." % log_path)
|
||||
print("[create_path] create path now...")
|
||||
os.makedirs(log_path)
|
||||
|
||||
def get_default_config(self):
|
||||
""" Make sure the path exists """
|
||||
log_dir = self.get_env_variable()
|
||||
self.log_debug = "%s/ci_test_log.debug" % log_dir
|
||||
self.log_info = "%s/ci_test_log.log" % log_dir
|
||||
self.log_err = "%s/ci_test_log.err" % log_dir
|
||||
work_log = os.environ.get('PYTEST_XDIST_WORKER')
|
||||
if work_log is not None:
|
||||
self.log_worker = f'{log_dir}/{work_log}.log'
|
||||
|
||||
self.create_path(log_dir)
|
||||
|
||||
|
||||
log_config = LogConfig()
|
||||
30
tests/restful_client/conftest.py
Normal file
30
tests/restful_client/conftest.py
Normal file
@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--host", action="store", default="127.0.0.1", help="host")
|
||||
parser.addoption("--port", action="store", default="19530", help="port")
|
||||
parser.addoption("--username", action="store", default="root", help="email")
|
||||
parser.addoption("--password", action="store", default="Milvus", help="password")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def host(request):
|
||||
return request.config.getoption("--host")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def port(request):
|
||||
return request.config.getoption("--port")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def username(request):
|
||||
return request.config.getoption("--username")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def password(request):
|
||||
return request.config.getoption("--password")
|
||||
|
||||
14
tests/restful_client/pytest.ini
Normal file
14
tests/restful_client/pytest.ini
Normal file
@ -0,0 +1,14 @@
|
||||
[pytest]
|
||||
addopts = --strict --host 127.0.0.1 --port 19530 --username root --password Milvus --log-cli-level=INFO --capture=no
|
||||
|
||||
log_format = [%(asctime)s - %(levelname)s - %(name)s]: %(message)s (%(filename)s:%(lineno)s)
|
||||
log_date_format = %Y-%m-%d %H:%M:%S
|
||||
|
||||
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
|
||||
markers =
|
||||
L0 : 'L0 case, high priority'
|
||||
L1 : 'L1 case, second priority'
|
||||
|
||||
12
tests/restful_client/requirements.txt
Normal file
12
tests/restful_client/requirements.txt
Normal file
@ -0,0 +1,12 @@
|
||||
requests~=2.26.0
|
||||
urllib3==1.26.16
|
||||
loguru~=0.5.3
|
||||
pytest~=7.2.0
|
||||
pyyaml~=6.0
|
||||
numpy~=1.24.3
|
||||
allure-pytest>=2.8.18
|
||||
Faker~=15.3.4
|
||||
pymilvus~=2.2.9
|
||||
sklearn~=0.0
|
||||
scikit-learn~=1.1.3
|
||||
yaml~=0.2.5
|
||||
467
tests/restful_client/testcases/test_collection_operations.py
Normal file
467
tests/restful_client/testcases/test_collection_operations.py
Normal file
@ -0,0 +1,467 @@
|
||||
import datetime
|
||||
import random
|
||||
import time
|
||||
from utils.util_log import test_log as logger
|
||||
import pytest
|
||||
from api.milvus import CollectionClient
|
||||
from base.testbase import TestBase
|
||||
import threading
|
||||
|
||||
|
||||
@pytest.mark.L0
|
||||
class TestCreateCollection(TestBase):
|
||||
|
||||
def teardown_method(self):
|
||||
try:
|
||||
all_collections = self.collection_client.collection_list()['data']
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
all_collections = []
|
||||
if all_collections is None:
|
||||
all_collections = []
|
||||
for collection in all_collections:
|
||||
name = collection
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
try:
|
||||
rsp = self.collection_client.collection_drop(payload)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
@pytest.mark.parametrize("vector_field", [None, "vector", "emb"])
|
||||
@pytest.mark.parametrize("primary_field", [None, "id", "doc_id"])
|
||||
@pytest.mark.parametrize("metric_type", ["L2", "IP"])
|
||||
@pytest.mark.parametrize("dim", [32, 32768])
|
||||
@pytest.mark.parametrize("db_name", ["prod", "default"])
|
||||
def test_create_collections_default(self, dim, metric_type, primary_field, vector_field, db_name):
|
||||
"""
|
||||
target: test create collection
|
||||
method: create a collection with a simple schema
|
||||
expected: create collection success
|
||||
"""
|
||||
self.create_database(db_name)
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
client.db_name = db_name
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
"metricType": metric_type,
|
||||
"primaryField": primary_field,
|
||||
"vectorField": vector_field,
|
||||
}
|
||||
if primary_field is None:
|
||||
del payload["primaryField"]
|
||||
if vector_field is None:
|
||||
del payload["vectorField"]
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
rsp = client.collection_list()
|
||||
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# describe collection
|
||||
rsp = client.collection_describe(name)
|
||||
assert rsp['code'] == 200
|
||||
assert rsp['data']['collectionName'] == name
|
||||
# assert f"FloatVector({dim})" in str(rsp['data']['fields'])
|
||||
|
||||
def test_create_collections_concurrent_with_same_param(self):
|
||||
"""
|
||||
target: test create collection with same param
|
||||
method: concurrent create collections with same param with multi thread
|
||||
expected: create collections all success
|
||||
"""
|
||||
concurrent_rsp = []
|
||||
|
||||
def create_collection(c_name, vector_dim, c_metric_type):
|
||||
collection_payload = {
|
||||
"collectionName": c_name,
|
||||
"dimension": vector_dim,
|
||||
"metricType": c_metric_type,
|
||||
}
|
||||
rsp = client.collection_create(collection_payload)
|
||||
concurrent_rsp.append(rsp)
|
||||
logger.info(rsp)
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
metric_type = "L2"
|
||||
client = self.collection_client
|
||||
threads = []
|
||||
for i in range(10):
|
||||
t = threading.Thread(target=create_collection, args=(name, dim, metric_type,))
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
time.sleep(10)
|
||||
success_cnt = 0
|
||||
for rsp in concurrent_rsp:
|
||||
if rsp["code"] == 200:
|
||||
success_cnt += 1
|
||||
logger.info(concurrent_rsp)
|
||||
assert success_cnt == 10
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# describe collection
|
||||
rsp = client.collection_describe(name)
|
||||
assert rsp['code'] == 200
|
||||
assert rsp['data']['collectionName'] == name
|
||||
# assert f"floatVector({dim})" in str(rsp['data']['fields'])
|
||||
|
||||
def test_create_collections_concurrent_with_different_param(self):
|
||||
"""
|
||||
target: test create collection with different param
|
||||
method: concurrent create collections with different param with multi thread
|
||||
expected: only one collection can success
|
||||
"""
|
||||
concurrent_rsp = []
|
||||
|
||||
def create_collection(c_name, vector_dim, c_metric_type):
|
||||
collection_payload = {
|
||||
"collectionName": c_name,
|
||||
"dimension": vector_dim,
|
||||
"metricType": c_metric_type,
|
||||
}
|
||||
rsp = client.collection_create(collection_payload)
|
||||
concurrent_rsp.append(rsp)
|
||||
logger.info(rsp)
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
threads = []
|
||||
for i in range(0, 5):
|
||||
t = threading.Thread(target=create_collection, args=(name, dim + i, "L2",))
|
||||
threads.append(t)
|
||||
for i in range(5, 10):
|
||||
t = threading.Thread(target=create_collection, args=(name, dim + i, "IP",))
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
time.sleep(10)
|
||||
success_cnt = 0
|
||||
for rsp in concurrent_rsp:
|
||||
if rsp["code"] == 200:
|
||||
success_cnt += 1
|
||||
logger.info(concurrent_rsp)
|
||||
assert success_cnt == 1
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# describe collection
|
||||
rsp = client.collection_describe(name)
|
||||
assert rsp['code'] == 200
|
||||
assert rsp['data']['collectionName'] == name
|
||||
|
||||
def test_create_collections_with_invalid_api_key(self):
|
||||
"""
|
||||
target: test create collection with invalid api key(wrong username and password)
|
||||
method: create collections with invalid api key
|
||||
expected: create collection failed
|
||||
"""
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
client.api_key = "illegal_api_key"
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 407
|
||||
|
||||
@pytest.mark.parametrize("name", [" ", "test_collection_" * 100, "test collection", "test/collection", "test\collection"])
|
||||
def test_create_collections_with_invalid_collection_name(self, name):
|
||||
"""
|
||||
target: test create collection with invalid collection name
|
||||
method: create collections with invalid collection name
|
||||
expected: create collection failed with right error message
|
||||
"""
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 1
|
||||
|
||||
|
||||
@pytest.mark.L0
|
||||
class TestListCollections(TestBase):
|
||||
|
||||
def teardown_method(self):
|
||||
try:
|
||||
all_collections = self.collection_client.collection_list()['data']
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
all_collections = []
|
||||
for collection in all_collections:
|
||||
name = collection
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
try:
|
||||
rsp = self.collection_client.collection_drop(payload)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def test_list_collections_default(self):
|
||||
"""
|
||||
target: test list collection with a simple schema
|
||||
method: create collections and list them
|
||||
expected: created collections are in list
|
||||
"""
|
||||
client = self.collection_client
|
||||
name_list = []
|
||||
for i in range(2):
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
time.sleep(1)
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
name_list.append(name)
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
for name in name_list:
|
||||
assert name in all_collections
|
||||
|
||||
def test_list_collections_with_invalid_api_key(self, url):
|
||||
"""
|
||||
target: test list collection with an invalid api key
|
||||
method: list collection with invalid api key
|
||||
expected: raise error with right error code and message
|
||||
"""
|
||||
client = self.collection_client
|
||||
name_list = []
|
||||
for i in range(2):
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
time.sleep(1)
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
name_list.append(name)
|
||||
client = self.collection_client
|
||||
client.api_key = "illegal_api_key"
|
||||
rsp = client.collection_list()
|
||||
assert rsp['code'] == 407
|
||||
|
||||
|
||||
@pytest.mark.L0
|
||||
class TestDescribeCollection(TestBase):
|
||||
|
||||
def teardown_method(self):
|
||||
try:
|
||||
all_collections = self.collection_client.collection_list()['data']
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
all_collections = []
|
||||
for collection in all_collections:
|
||||
name = collection
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
try:
|
||||
rsp = self.collection_client.collection_drop(payload)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def test_describe_collections_default(self):
|
||||
"""
|
||||
target: test describe collection with a simple schema
|
||||
method: describe collection
|
||||
expected: info of description is same with param passed to create collection
|
||||
"""
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# describe collection
|
||||
rsp = client.collection_describe(name)
|
||||
assert rsp['code'] == 200
|
||||
assert rsp['data']['collectionName'] == name
|
||||
# assert f"floatVector({dim})" in str(rsp['data']['fields'])
|
||||
|
||||
def test_describe_collections_with_invalid_api_key(self):
|
||||
"""
|
||||
target: test describe collection with invalid api key
|
||||
method: describe collection with invalid api key
|
||||
expected: raise error with right error code and message
|
||||
"""
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# describe collection
|
||||
illegal_client = CollectionClient(self.url, "illegal_api_key")
|
||||
rsp = illegal_client.collection_describe(name)
|
||||
assert rsp['code'] == 407
|
||||
|
||||
def test_describe_collections_with_invalid_collection_name(self):
|
||||
"""
|
||||
target: test describe collection with invalid collection name
|
||||
method: describe collection with invalid collection name
|
||||
expected: raise error with right error code and message
|
||||
"""
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# describe collection
|
||||
invalid_name = "invalid_name"
|
||||
rsp = client.collection_describe(invalid_name)
|
||||
assert rsp['code'] == 1
|
||||
|
||||
|
||||
@pytest.mark.L0
|
||||
class TestDropCollection(TestBase):
|
||||
|
||||
def teardown_method(self):
|
||||
try:
|
||||
all_collections = self.collection_client.collection_list()['data']
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
all_collections = []
|
||||
for collection in all_collections:
|
||||
name = collection
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
try:
|
||||
rsp = self.collection_client.collection_drop(payload)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
def test_drop_collections_default(self):
|
||||
"""
|
||||
Drop a collection with a simple schema
|
||||
target: test drop collection with a simple schema
|
||||
method: drop collection
|
||||
expected: dropped collection was not in collection list
|
||||
"""
|
||||
rsp = self.collection_client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
for name in all_collections:
|
||||
time.sleep(1)
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
rsp = self.collection_client.collection_drop(payload)
|
||||
assert rsp['code'] == 200
|
||||
clo_list = []
|
||||
for i in range(2):
|
||||
time.sleep(1)
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": 128,
|
||||
}
|
||||
rsp = self.collection_client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
clo_list.append(name)
|
||||
rsp = self.collection_client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
for name in all_collections:
|
||||
time.sleep(0.2)
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
rsp = self.collection_client.collection_drop(payload)
|
||||
assert rsp['code'] == 200
|
||||
rsp = self.collection_client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
for name in clo_list:
|
||||
assert name not in all_collections
|
||||
|
||||
def test_drop_collections_with_invalid_api_key(self):
|
||||
"""
|
||||
target: test drop collection with invalid api key
|
||||
method: drop collection with invalid api key
|
||||
expected: raise error with right error code and message; collection still in collection list
|
||||
"""
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# drop collection
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
}
|
||||
illegal_client = CollectionClient(self.url, "invalid_api_key")
|
||||
rsp = illegal_client.collection_drop(payload)
|
||||
assert rsp['code'] == 407
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
|
||||
def test_drop_collections_with_invalid_collection_name(self):
|
||||
"""
|
||||
target: test drop collection with invalid collection name
|
||||
method: drop collection with invalid collection name
|
||||
expected: raise error with right error code and message
|
||||
"""
|
||||
name = 'test_collection_' + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
||||
dim = 128
|
||||
client = self.collection_client
|
||||
payload = {
|
||||
"collectionName": name,
|
||||
"dimension": dim,
|
||||
}
|
||||
rsp = client.collection_create(payload)
|
||||
assert rsp['code'] == 200
|
||||
rsp = client.collection_list()
|
||||
all_collections = rsp['data']
|
||||
assert name in all_collections
|
||||
# drop collection
|
||||
invalid_name = "invalid_name"
|
||||
payload = {
|
||||
"collectionName": invalid_name,
|
||||
}
|
||||
rsp = client.collection_drop(payload)
|
||||
assert rsp['code'] == 1
|
||||
1462
tests/restful_client/testcases/test_vector_operations.py
Normal file
1462
tests/restful_client/testcases/test_vector_operations.py
Normal file
File diff suppressed because it is too large
Load Diff
2
tests/restful_client/utils/constant.py
Normal file
2
tests/restful_client/utils/constant.py
Normal file
@ -0,0 +1,2 @@
|
||||
|
||||
MAX_SUM_OFFSET_AND_LIMIT = 16384
|
||||
60
tests/restful_client/utils/util_log.py
Normal file
60
tests/restful_client/utils/util_log.py
Normal file
@ -0,0 +1,60 @@
|
||||
import logging
|
||||
from loguru import logger as loguru_logger
|
||||
import sys
|
||||
|
||||
from config.log_config import log_config
|
||||
|
||||
|
||||
class TestLog:
|
||||
def __init__(self, logger, log_debug, log_file, log_err, log_worker):
|
||||
self.logger = logger
|
||||
self.log_debug = log_debug
|
||||
self.log_file = log_file
|
||||
self.log_err = log_err
|
||||
self.log_worker = log_worker
|
||||
|
||||
self.log = logging.getLogger(self.logger)
|
||||
self.log.setLevel(logging.DEBUG)
|
||||
|
||||
try:
|
||||
formatter = logging.Formatter("[%(asctime)s - %(levelname)s - %(name)s]: "
|
||||
"%(message)s (%(filename)s:%(lineno)s)")
|
||||
# [%(process)s] process NO.
|
||||
dh = logging.FileHandler(self.log_debug)
|
||||
dh.setLevel(logging.DEBUG)
|
||||
dh.setFormatter(formatter)
|
||||
self.log.addHandler(dh)
|
||||
|
||||
fh = logging.FileHandler(self.log_file)
|
||||
fh.setLevel(logging.INFO)
|
||||
fh.setFormatter(formatter)
|
||||
self.log.addHandler(fh)
|
||||
|
||||
eh = logging.FileHandler(self.log_err)
|
||||
eh.setLevel(logging.ERROR)
|
||||
eh.setFormatter(formatter)
|
||||
self.log.addHandler(eh)
|
||||
|
||||
if self.log_worker != "":
|
||||
wh = logging.FileHandler(self.log_worker)
|
||||
wh.setLevel(logging.DEBUG)
|
||||
wh.setFormatter(formatter)
|
||||
self.log.addHandler(wh)
|
||||
|
||||
ch = logging.StreamHandler(sys.stdout)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
ch.setFormatter(formatter)
|
||||
# self.log.addHandler(ch)
|
||||
|
||||
except Exception as e:
|
||||
print("Can not use %s or %s or %s to log. error : %s" % (log_debug, log_file, log_err, str(e)))
|
||||
|
||||
|
||||
"""All modules share this unified log"""
|
||||
log_debug = log_config.log_debug
|
||||
log_info = log_config.log_info
|
||||
log_err = log_config.log_err
|
||||
log_worker = log_config.log_worker
|
||||
self_defined_log = TestLog('ci_test', log_debug, log_info, log_err, log_worker).log
|
||||
loguru_log = loguru_logger
|
||||
test_log = loguru_logger
|
||||
145
tests/restful_client/utils/utils.py
Normal file
145
tests/restful_client/utils/utils.py
Normal file
@ -0,0 +1,145 @@
|
||||
import random
|
||||
import time
|
||||
from faker import Faker
|
||||
import numpy as np
|
||||
from sklearn import preprocessing
|
||||
import requests
|
||||
from loguru import logger
|
||||
fake = Faker()
|
||||
def admin_password():
|
||||
return "Zilliz@123"
|
||||
|
||||
|
||||
def invalid_cluster_name():
|
||||
res = [
|
||||
"demo" * 100,
|
||||
"demo" + "!",
|
||||
"demo" + "@",
|
||||
]
|
||||
return res
|
||||
|
||||
|
||||
def wait_cluster_be_ready(cluster_id, client, timeout=120):
|
||||
t0 = time.time()
|
||||
while True and time.time() - t0 < timeout:
|
||||
rsp = client.cluster_describe(cluster_id)
|
||||
if rsp['code'] == 200:
|
||||
if rsp['data']['status'] == "RUNNING":
|
||||
return time.time() - t0
|
||||
time.sleep(1)
|
||||
logger.debug("wait cluster to be ready, cost time: %s" % (time.time() - t0))
|
||||
return -1
|
||||
|
||||
|
||||
def force_delete_cluster(cluster_id):
|
||||
url = f"https://cloud-test.cloud-uat3.zilliz.com/cloud/v1/test/deleteInstance?instanceId={cluster_id}"
|
||||
rsp = requests.get(url).json()
|
||||
logger.info(rsp)
|
||||
assert rsp["Code"] == 0
|
||||
assert rsp["Message"] == "success"
|
||||
|
||||
|
||||
def gen_data_by_type(field):
|
||||
data_type = field["type"]
|
||||
if data_type == "bool":
|
||||
return random.choice([True, False])
|
||||
if data_type == "int8":
|
||||
return random.randint(-128, 127)
|
||||
if data_type == "int16":
|
||||
return random.randint(-32768, 32767)
|
||||
if data_type == "int32":
|
||||
return random.randint(-2147483648, 2147483647)
|
||||
if data_type == "int64":
|
||||
return random.randint(-9223372036854775808, 9223372036854775807)
|
||||
if data_type == "float32":
|
||||
return np.float64(random.random()) # Object of type float32 is not JSON serializable, so set it as float64
|
||||
if data_type == "float64":
|
||||
return np.float64(random.random())
|
||||
if "varchar" in data_type:
|
||||
length = int(data_type.split("(")[1].split(")")[0])
|
||||
return "".join([chr(random.randint(97, 122)) for _ in range(length)])
|
||||
if "floatVector" in data_type:
|
||||
dim = int(data_type.split("(")[1].split(")")[0])
|
||||
return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
|
||||
return None
|
||||
|
||||
|
||||
def get_data_by_fields(fields, nb):
|
||||
# logger.info(f"fields: {fields}")
|
||||
fields_not_auto_id = []
|
||||
for field in fields:
|
||||
if not field.get("autoId", False):
|
||||
fields_not_auto_id.append(field)
|
||||
# logger.info(f"fields_not_auto_id: {fields_not_auto_id}")
|
||||
data = []
|
||||
for i in range(nb):
|
||||
tmp = {}
|
||||
for field in fields_not_auto_id:
|
||||
tmp[field["name"]] = gen_data_by_type(field)
|
||||
data.append(tmp)
|
||||
return data
|
||||
|
||||
|
||||
def get_random_json_data(uid=None):
|
||||
# gen random dict data
|
||||
if uid is None:
|
||||
uid = 0
|
||||
data = {"uid": uid, "name": fake.name(), "address": fake.address(), "text": fake.text(), "email": fake.email(),
|
||||
"phone_number": fake.phone_number(),
|
||||
"json": {
|
||||
"name": fake.name(),
|
||||
"address": fake.address()
|
||||
}
|
||||
}
|
||||
for i in range(random.randint(1, 10)):
|
||||
data["key" + str(random.randint(1, 100_000))] = "value" + str(random.randint(1, 100_000))
|
||||
return data
|
||||
|
||||
|
||||
def get_data_by_payload(payload, nb=100):
|
||||
dim = payload.get("dimension", 128)
|
||||
vector_field = payload.get("vectorField", "vector")
|
||||
data = []
|
||||
if nb == 1:
|
||||
data = [{
|
||||
vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
|
||||
**get_random_json_data()
|
||||
|
||||
}]
|
||||
else:
|
||||
for i in range(nb):
|
||||
data.append({
|
||||
vector_field: preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist(),
|
||||
**get_random_json_data(uid=i)
|
||||
})
|
||||
return data
|
||||
|
||||
|
||||
def get_common_fields_by_data(data, exclude_fields=None):
|
||||
fields = set()
|
||||
if isinstance(data, dict):
|
||||
data = [data]
|
||||
if not isinstance(data, list):
|
||||
raise Exception("data must be list or dict")
|
||||
common_fields = set(data[0].keys())
|
||||
for d in data:
|
||||
keys = set(d.keys())
|
||||
common_fields = common_fields.intersection(keys)
|
||||
if exclude_fields is not None:
|
||||
exclude_fields = set(exclude_fields)
|
||||
common_fields = common_fields.difference(exclude_fields)
|
||||
return list(common_fields)
|
||||
|
||||
|
||||
def get_all_fields_by_data(data, exclude_fields=None):
|
||||
fields = set()
|
||||
for d in data:
|
||||
keys = list(d.keys())
|
||||
fields.union(keys)
|
||||
if exclude_fields is not None:
|
||||
exclude_fields = set(exclude_fields)
|
||||
fields = fields.difference(exclude_fields)
|
||||
return list(fields)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user