test: add COO-format sparse vectors to the RESTful tests (#33689)

pr: https://github.com/milvus-io/milvus/pull/33677

Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
This commit is contained in:
zhuwenxing 2024-06-07 09:19:58 +08:00 committed by GitHub
parent 3562ef83b2
commit b78d7edca6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 35 additions and 6 deletions

View File

@ -838,8 +838,9 @@ class TestSearchVector(TestBase):
@pytest.mark.parametrize("nb", [3000])
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("groupingField", ['user_id', None])
@pytest.mark.parametrize("sparse_format", ['dok', 'coo'])
def test_search_vector_with_sparse_float_vector_datatype(self, nb, dim, insert_round, auto_id,
is_partition_key, enable_dynamic_schema, groupingField):
is_partition_key, enable_dynamic_schema, groupingField, sparse_format):
"""
Insert a vector with a simple payload
"""
@ -879,7 +880,7 @@ class TestSearchVector(TestBase):
"user_id": idx%100,
"word_count": j,
"book_describe": f"book_{idx}",
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim),
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format=sparse_format),
}
else:
tmp = {
@ -887,7 +888,7 @@ class TestSearchVector(TestBase):
"user_id": idx%100,
"word_count": j,
"book_describe": f"book_{idx}",
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim),
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format=sparse_format),
}
if enable_dynamic_schema:
tmp.update({f"dynamic_field_{i}": i})
@ -902,7 +903,7 @@ class TestSearchVector(TestBase):
# search data
payload = {
"collectionName": name,
"data": [gen_vector(datatype="SparseFloatVector", dim=dim)],
"data": [gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="dok")],
"filter": "word_count > 100",
"outputFields": ["*"],
"searchParams": {
@ -918,6 +919,24 @@ class TestSearchVector(TestBase):
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0
# search data
payload = {
"collectionName": name,
"data": [gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="coo")],
"filter": "word_count > 100",
"outputFields": ["*"],
"searchParams": {
"metricType": "IP",
"params": {
"drop_ratio_search": "0.2",
}
},
"limit": 500,
}
if groupingField:
payload["groupingField"] = groupingField
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0
@pytest.mark.parametrize("insert_round", [2])
@pytest.mark.parametrize("auto_id", [True])

View File

@ -197,12 +197,22 @@ def gen_bf16_vectors(num, dim):
return raw_vectors, bf16_vectors
def gen_vector(datatype="float_vector", dim=128, binary_data=False):
def gen_vector(datatype="float_vector", dim=128, binary_data=False, sparse_format='dok'):
value = None
if datatype == "FloatVector":
return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
if datatype == "SparseFloatVector":
return {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))}
if sparse_format == 'dok':
return {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))}
elif sparse_format == 'coo':
data = {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))}
coo_data = {
"indices": list(data.keys()),
"values": list(data.values())
}
return coo_data
else:
raise Exception(f"unsupported sparse format: {sparse_format}")
if datatype == "BinaryVector":
value = gen_binary_vectors(1, dim)[1][0]
if datatype == "Float16Vector":