diff --git a/tests/restful_client_v2/testcases/test_vector_operations.py b/tests/restful_client_v2/testcases/test_vector_operations.py index d1310a5b19..bce1f9acca 100644 --- a/tests/restful_client_v2/testcases/test_vector_operations.py +++ b/tests/restful_client_v2/testcases/test_vector_operations.py @@ -838,8 +838,9 @@ class TestSearchVector(TestBase): @pytest.mark.parametrize("nb", [3000]) @pytest.mark.parametrize("dim", [128]) @pytest.mark.parametrize("groupingField", ['user_id', None]) + @pytest.mark.parametrize("sparse_format", ['dok', 'coo']) def test_search_vector_with_sparse_float_vector_datatype(self, nb, dim, insert_round, auto_id, - is_partition_key, enable_dynamic_schema, groupingField): + is_partition_key, enable_dynamic_schema, groupingField, sparse_format): """ Insert a vector with a simple payload """ @@ -879,7 +880,7 @@ class TestSearchVector(TestBase): "user_id": idx%100, "word_count": j, "book_describe": f"book_{idx}", - "sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim), + "sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format=sparse_format), } else: tmp = { @@ -887,7 +888,7 @@ class TestSearchVector(TestBase): "user_id": idx%100, "word_count": j, "book_describe": f"book_{idx}", - "sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim), + "sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format=sparse_format), } if enable_dynamic_schema: tmp.update({f"dynamic_field_{i}": i}) @@ -902,7 +903,7 @@ class TestSearchVector(TestBase): # search data payload = { "collectionName": name, - "data": [gen_vector(datatype="SparseFloatVector", dim=dim)], + "data": [gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="dok")], "filter": "word_count > 100", "outputFields": ["*"], "searchParams": { @@ -918,6 +919,24 @@ class TestSearchVector(TestBase): rsp = self.vector_client.vector_search(payload) assert rsp['code'] == 0 + # search data + payload = { + "collectionName": name, + "data": [gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="coo")], + "filter": "word_count > 100", + "outputFields": ["*"], + "searchParams": { + "metricType": "IP", + "params": { + "drop_ratio_search": "0.2", + } + }, + "limit": 500, + } + if groupingField: + payload["groupingField"] = groupingField + rsp = self.vector_client.vector_search(payload) + assert rsp['code'] == 0 @pytest.mark.parametrize("insert_round", [2]) @pytest.mark.parametrize("auto_id", [True]) diff --git a/tests/restful_client_v2/utils/utils.py b/tests/restful_client_v2/utils/utils.py index 466bf1d453..112e26e787 100644 --- a/tests/restful_client_v2/utils/utils.py +++ b/tests/restful_client_v2/utils/utils.py @@ -197,12 +197,22 @@ def gen_bf16_vectors(num, dim): return raw_vectors, bf16_vectors -def gen_vector(datatype="float_vector", dim=128, binary_data=False): +def gen_vector(datatype="float_vector", dim=128, binary_data=False, sparse_format='dok'): value = None if datatype == "FloatVector": return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist() if datatype == "SparseFloatVector": - return {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))} + if sparse_format == 'dok': + return {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))} + elif sparse_format == 'coo': + data = {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))} + coo_data = { + "indices": list(data.keys()), + "values": list(data.values()) + } + return coo_data + else: + raise Exception(f"unsupported sparse format: {sparse_format}") if datatype == "BinaryVector": value = gen_binary_vectors(1, dim)[1][0] if datatype == "Float16Vector":