mirror of https://gitee.com/milvus-io/milvus.git (synced 2026-01-07 19:31:51 +08:00)
impl part of search

commit 86a893cb04 (parent 4fc6f0a520)
@@ -2,3 +2,5 @@ INVALID_CODE = -1

CONNECT_ERROR_CODE = 10001
CONNECTTION_NOT_FOUND_CODE = 10002

TABLE_NOT_FOUND_CODE = 20001

@@ -11,3 +11,6 @@ class ConnectionConnectError(BaseException):

class ConnectionNotFoundError(BaseException):
    code = codes.CONNECTTION_NOT_FOUND_CODE

class TableNotFoundError(BaseException):
    code = codes.TABLE_NOT_FOUND_CODE

@@ -7,6 +7,7 @@ from mishards import connect_mgr, grpc_server as server

def main():
    connect_mgr.register('WOSERVER', settings.WOSERVER)
    connect_mgr.register('TEST', 'tcp://127.0.0.1:19530')
    server.run(port=settings.SERVER_PORT)
    return 0

@@ -1,13 +1,22 @@
import logging
import time
import datetime
from contextlib import contextmanager
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from collections import defaultdict

from concurrent.futures import ThreadPoolExecutor
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
from milvus.client import types

import settings
from grpc_utils.grpc_args_parser import GrpcArgsParser as Parser

logger = logging.getLogger(__name__)


class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
    MAX_NPROBE = 2048
    def __init__(self, conn_mgr, *args, **kwargs):
        self.conn_mgr = conn_mgr
        self.table_meta = {}
@@ -19,6 +28,99 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        conn.on_connect()
        return conn.conn

    def query_conn(self, name):
        conn = self.conn_mgr.conn(name)
        conn and conn.on_connect()
        return conn.conn

    def _format_date(self, start, end):
        return ((start.year-1900)*10000 + (start.month-1)*100 + start.day
                , (end.year-1900)*10000 + (end.month-1)*100 + end.day)

    def _range_to_date(self, range_obj):
        try:
            start = datetime.datetime.strptime(range_obj.start_date, '%Y-%m-%d')
            end = datetime.datetime.strptime(range_obj.end_date, '%Y-%m-%d')
            assert start >= end
        except (ValueError, AssertionError):
            raise exceptions.InvalidRangeError('Invalid time range: {} {}'.format(
                range_obj.start_date, range_obj.end_date
            ))

        return self._format_date(start, end)

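As an aside, the packed-date arithmetic in _format_date is easy to check by hand. A minimal sketch (the sample date below is arbitrary and only for illustration, not taken from the commit):

    import datetime

    start = datetime.datetime.strptime('2019-06-25', '%Y-%m-%d')
    # Same packing as _format_date: (year - 1900) * 10000 + (month - 1) * 100 + day
    packed = (start.year - 1900) * 10000 + (start.month - 1) * 100 + start.day
    print(packed)  # 1190525 -> 119 years since 1900, month index 5, day 25
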
    def _get_routing_file_ids(self, table_id, range_array):
        return {
            'TEST': {
                'table_id': table_id,
                'file_ids': [123]
            }
        }

    def _do_merge(self, files_n_topk_results, topk, reverse=False):
        if not files_n_topk_results:
            return []

        request_results = defaultdict(list)

        calc_time = time.time()
        for files_collection in files_n_topk_results:
            for request_pos, each_request_results in enumerate(files_collection.topk_query_result):
                request_results[request_pos].extend(each_request_results.query_result_arrays)
                request_results[request_pos] = sorted(request_results[request_pos], key=lambda x: x.distance,
                                                      reverse=reverse)[:topk]

        calc_time = time.time() - calc_time
        logger.info('Merge takes {}'.format(calc_time))

        results = sorted(request_results.items())
        topk_query_result = []

        for result in results:
            query_result = TopKQueryResult(query_result_arrays=result[1])
            topk_query_result.append(query_result)

        return topk_query_result

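The merge step above re-ranks partial results from every routed file group: candidates are pooled per request position, sorted by distance, and cut back to the global top-k. A standalone toy version of the same idea, using (id, distance) tuples in place of the protobuf query_result_arrays (all values below are made up for illustration):

    from collections import defaultdict

    partial_a = [[(1, 0.10), (2, 0.40)]]   # top-2 for request 0 from one worker
    partial_b = [[(7, 0.05), (9, 0.35)]]   # top-2 for request 0 from another worker

    topk = 2
    request_results = defaultdict(list)
    for partial in (partial_a, partial_b):
        for request_pos, hits in enumerate(partial):
            request_results[request_pos].extend(hits)
            # Re-sort the pooled candidates by distance and keep the global top-k.
            request_results[request_pos] = sorted(
                request_results[request_pos], key=lambda x: x[1])[:topk]

    print(dict(request_results))  # {0: [(7, 0.05), (1, 0.1)]}
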
    def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
        range_array = [self._range_to_date(r) for r in range_array] if range_array else None
        routing = self._get_routing_file_ids(table_id, range_array)
        logger.debug(routing)

        rs = []
        all_topk_results = []

        workers = settings.SEARCH_WORKER_SIZE

        def search(addr, query_params, vectors, topk, nprobe, **kwargs):
            logger.info('Send Search Request: addr={};params={};nq={};topk={};nprobe={}'.format(
                addr, query_params, len(vectors), topk, nprobe
            ))

            conn = self.query_conn(addr)
            start = time.time()
            ret = conn.search_vectors_in_files(table_name=query_params['table_id'],
                                               file_ids=query_params['file_ids'],
                                               query_records=vectors,
                                               top_k=topk,
                                               nprobe=nprobe,
                                               lazy=True)
            end = time.time()
            logger.info('search_vectors_in_files takes: {}'.format(end - start))

            all_topk_results.append(ret)

        with ThreadPoolExecutor(max_workers=workers) as pool:
            for addr, params in routing.items():
                res = pool.submit(search, addr, params, vectors, topk, nprobe)
                rs.append(res)

        for res in rs:
            res.result()

        reverse = table_meta.metric_type == types.MetricType.L2
        return self._do_merge(all_topk_results, topk, reverse=reverse)

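The worker fan-out in _do_query boils down to a submit-then-wait pattern: one search task per routed address, with res.result() re-raising any exception thrown inside a worker. A self-contained sketch of just that pattern (fake_search, the addresses, and the routing table below are placeholders rather than part of this commit; here the workers return their results instead of appending to a shared list):

    from concurrent.futures import ThreadPoolExecutor

    def fake_search(addr, vectors, topk):
        # Stand-in for conn.search_vectors_in_files(...).
        return (addr, len(vectors), topk)

    routing = {'TEST': {'file_ids': [123]}, 'OTHER': {'file_ids': [456]}}
    vectors, topk = [[0.1, 0.2]], 5

    futures, results = [], []
    with ThreadPoolExecutor(max_workers=4) as pool:
        for addr, params in routing.items():
            futures.append(pool.submit(fake_search, addr, vectors, topk))

    # result() blocks until the task finishes and re-raises worker exceptions.
    for f in futures:
        results.append(f.result())

    print(results)
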
    def CreateTable(self, request, context):
        _status, _table_schema = Parser.parse_proto_TableSchema(request)

@@ -87,64 +189,64 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):

    def Search(self, request, context):

        try:
            table_name = request.table_name
            table_name = request.table_name

            topk = request.topk
            nprobe = request.nprobe
            topk = request.topk
            nprobe = request.nprobe

            logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
            logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))

            if nprobe > 2048 or nprobe <= 0:
                raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
            if nprobe > self.MAX_NPROBE or nprobe <= 0:
                raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))

            table_meta = self.table_meta.get(table_name, None)
            if not table_meta:
                status, info = self.connection.describe_table(table_name)
                if not status.OK():
                    raise TableNotFoundException(table_name)
            table_meta = self.table_meta.get(table_name, None)
            if not table_meta:
                status, info = self.connection.describe_table(table_name)
                if not status.OK():
                    raise exceptions.TableNotFoundError(table_name)

                self.table_meta[table_name] = info
                table_meta = info
                self.table_meta[table_name] = info
                table_meta = info

            start = time.time()
            start = time.time()

            query_record_array = []
            query_record_array = []

            for query_record in request.query_record_array:
                query_record_array.append(list(query_record.vector_data))
            for query_record in request.query_record_array:
                query_record_array.append(list(query_record.vector_data))

            query_range_array = []
            for query_range in request.query_range_array:
                query_range_array.append(
                    Range(query_range.start_value, query_range.end_value))
        except (TableNotFoundException, exceptions.GRPCInvlidArgument) as exc:
            return milvus_pb2.TopKQueryResultList(
                status=status_pb2.Status(error_code=exc.code, reason=exc.message)
            )
        except Exception as e:
            return milvus_pb2.TopKQueryResultList(
                status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e))
            )
            query_range_array = []
            for query_range in request.query_range_array:
                query_range_array.append(
                    Range(query_range.start_value, query_range.end_value))
        # except (TableNotFoundException, exceptions.GRPCInvlidArgument) as exc:
        #     return milvus_pb2.TopKQueryResultList(
        #         status=status_pb2.Status(error_code=exc.code, reason=exc.message)
        #     )
        # except Exception as e:
        #     return milvus_pb2.TopKQueryResultList(
        #         status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e))
        #     )

        try:
            results = workflow.query_vectors(table_name, table_meta, query_record_array, topk,
                                             nprobe, query_range_array)
        except (exceptions.GRPCQueryInvalidRangeException, TableNotFoundException) as exc:
            return milvus_pb2.TopKQueryResultList(
                status=status_pb2.Status(error_code=exc.code, reason=exc.message)
            )
        except exceptions.ServiceNotFoundException as exc:
            return milvus_pb2.TopKQueryResultList(
                status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=exc.message)
            )
        except Exception as e:
            logger.error(e)
            results = workflow.query_vectors(table_name, table_meta, query_record_array,
                                             topk, nprobe, query_range_array)
        results = self._do_query(table_name, table_meta, query_record_array, topk,
                                 nprobe, query_range_array)
        # try:
        #     results = workflow.query_vectors(table_name, table_meta, query_record_array, topk,
        #                                      nprobe, query_range_array)
        # except (exceptions.GRPCQueryInvalidRangeException, TableNotFoundException) as exc:
        #     return milvus_pb2.TopKQueryResultList(
        #         status=status_pb2.Status(error_code=exc.code, reason=exc.message)
        #     )
        # except exceptions.ServiceNotFoundException as exc:
        #     return milvus_pb2.TopKQueryResultList(
        #         status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=exc.message)
        #     )
        # except Exception as e:
        #     logger.error(e)
        #     results = workflow.query_vectors(table_name, table_meta, query_record_array,
        #                                      topk, nprobe, query_range_array)

        now = time.time()
        logger.info('SearchVector Ends @{}'.format(now))
        logger.info('SearchVector takes: {}'.format(now - start))

        topk_result_list = milvus_pb2.TopKQueryResultList(

@@ -154,41 +256,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
        return topk_result_list

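One detail worth noting in the Search handler above is the describe-once-then-cache step for table metadata: the first request for a table calls describe_table and stores the result in self.table_meta, and later requests reuse it. A stripped-down sketch of that pattern with a stubbed describe_table (FakeStatus and the stub are illustrative only, not part of the commit):

    table_meta_cache = {}

    class FakeStatus:
        def OK(self):
            return True

    def describe_table(table_name):
        # Stand-in for connection.describe_table(); returns (status, metadata).
        return FakeStatus(), {'table_name': table_name, 'metric_type': 'L2'}

    def get_table_meta(table_name):
        meta = table_meta_cache.get(table_name, None)
        if not meta:
            status, info = describe_table(table_name)
            if not status.OK():
                raise KeyError(table_name)  # the handler raises a table-not-found error here
            table_meta_cache[table_name] = info
            meta = info
        return meta

    print(get_table_meta('demo'))  # first call hits the stub
    print(get_table_meta('demo'))  # second call is served from the cache
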
    def SearchInFiles(self, request, context):
        try:
            file_id_array = list(request.file_id_array)
            search_param = request.search_param
            table_name = search_param.table_name
            topk = search_param.topk
            nprobe = search_param.nprobe

            query_record_array = []

            for query_record in search_param.query_record_array:
                query_record_array.append(list(query_record))

            query_range_array = []
            for query_range in search_param.query_range_array:
                query_range_array.append("")
        except Exception as e:
            milvus_pb2.TopKQueryResultList(
                status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e)),
            )

        res = search_vector_in_files.delay(table_name=table_name,
                                           file_id_array=file_id_array,
                                           query_record_array=query_record_array,
                                           query_range_array=query_range_array,
                                           topk=topk,
                                           nprobe=nprobe)
        status, result = res.get(timeout=1)

        if not status.OK():
            raise ThriftException(code=status.code, reason=status.message)
        res = TopKQueryResult()
        for top_k_query_results in result:
            res.query_result_arrays.append([QueryResult(id=qr.id, distance=qr.distance)
                                            for qr in top_k_query_results])
        return res
        raise NotImplemented()

    def DescribeTable(self, request, context):
        _status, _table_name = Parser.parse_proto_TableName(request)

@@ -21,6 +21,7 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)

TIMEOUT = env.int('TIMEOUT', 60)
MAX_RETRY = env.int('MAX_RETRY', 3)
SEARCH_WORKER_SIZE = env.int('SEARCH_WORKER_SIZE', 10)

SERVER_PORT = env.int('SERVER_PORT', 19530)
WOSERVER = env.str('WOSERVER')
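The new SEARCH_WORKER_SIZE knob caps the thread pool used by _do_query. Assuming env here is an environs.Env instance (the usual setup for this style of settings module; its construction is not shown in this diff), the value can be overridden from the environment and otherwise falls back to 10:

    from environs import Env

    env = Env()
    env.read_env()  # load a .env file if one is present

    # Defaults to 10 unless SEARCH_WORKER_SIZE is set in the environment.
    SEARCH_WORKER_SIZE = env.int('SEARCH_WORKER_SIZE', 10)
    print(SEARCH_WORKER_SIZE)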