diff --git a/CHANGELOG.md b/CHANGELOG.md index adaf650a17..11a931dfda 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Please mark all change in change log and use the issue from GitHub - \#2952 Fix the result merging of IVF_PQ IP - \#2975 Fix config UT failed - \#3012 If the cache is too small, queries using multiple GPUs will cause to crash +- \#3133 Reverse query result in mishards if metric type is IP ## Feature diff --git a/shards/mishards/connections.py b/shards/mishards/connections.py index cff5b39066..6d1c668ddd 100644 --- a/shards/mishards/connections.py +++ b/shards/mishards/connections.py @@ -224,6 +224,16 @@ logger = logging.getLogger(__name__) # connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs) # return connection +def version_supported(version): + version_pattern = lambda v : ".".join(v.split(".")[:2]) + + sv_patterns = set() + for supported_version in settings.SERVER_VERSIONS: + sv_patterns.add(version_pattern(supported_version)) + + v_pattern = version_pattern(version) + return v_pattern in sv_patterns + class ConnectionGroup(topology.TopoGroup): def __init__(self, name): @@ -243,7 +253,7 @@ class ConnectionGroup(topology.TopoGroup): if not status.OK(): logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name)) return False - if version not in settings.SERVER_VERSIONS: + if not version_supported(version): logger.error('Cannot connect to server of version: {}. Only {} supported'.format(version, settings.SERVER_VERSIONS)) return False diff --git a/shards/mishards/service_handler.py b/shards/mishards/service_handler.py index e6bb749e85..e64f1f1fb0 100644 --- a/shards/mishards/service_handler.py +++ b/shards/mishards/service_handler.py @@ -27,14 +27,15 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): self.max_workers = max_workers def _reduce(self, source_ids, ids, source_diss, diss, k, reverse): - if source_diss[k - 1] <= diss[0]: + sort_f = lambda x, y: x >= y if reverse else lambda x, y: x <= y + if sort_f(source_diss[k - 1], diss[0]): return source_ids, source_diss - if diss[k - 1] <= source_diss[0]: + if sort_f(diss[k - 1], source_diss[0]): return ids, diss source_diss.extend(diss) diss_t = enumerate(source_diss) - diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k] + diss_m_rst = sorted(diss_t, key=lambda x: x[1], reverse=reverse)[:k] diss_m_out = [id_ for _, id_ in diss_m_rst] source_ids.extend(ids) @@ -149,9 +150,9 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer): params=search_params, _async=True) futures.append(future) - for f in futures: - ret = f.result(raw=True) - all_topk_results.append(ret) + for f in futures: + ret = f.result(raw=True) + all_topk_results.append(ret) reverse = collection_meta.metric_type == Types.MetricType.IP with self.tracer.start_span('do_merge', child_of=p_span): diff --git a/shards/mishards/settings.py b/shards/mishards/settings.py index 8530f45c84..ca24d1dd3e 100644 --- a/shards/mishards/settings.py +++ b/shards/mishards/settings.py @@ -12,7 +12,7 @@ else: env.read_env() -SERVER_VERSIONS = ['0.9.0', '0.9.1', '0.10.0', '0.10.1'] +SERVER_VERSIONS = ['0.9.x', '0.10.x'] DEBUG = env.bool('DEBUG', False) MAX_RETRY = env.int('MAX_RETRY', 3)