mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
add grpc server
This commit is contained in:
parent
052d79a58d
commit
4fc6f0a520
@ -1 +0,0 @@
|
||||
import settings
|
||||
6
mishards/__init__.py
Normal file
6
mishards/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
import settings
|
||||
from connections import ConnectionMgr
|
||||
connect_mgr = ConnectionMgr()
|
||||
|
||||
from server import Server
|
||||
grpc_server = Server(conn_mgr=connect_mgr)
|
||||
@ -89,7 +89,7 @@ class ConnectionMgr:
|
||||
threaded = {
|
||||
threading.get_ident() : this_conn
|
||||
}
|
||||
c[name] = threaded
|
||||
self.conns[name] = threaded
|
||||
return this_conn
|
||||
|
||||
tid = threading.get_ident()
|
||||
0
mishards/grpc_utils/__init__.py
Normal file
0
mishards/grpc_utils/__init__.py
Normal file
101
mishards/grpc_utils/grpc_args_parser.py
Normal file
101
mishards/grpc_utils/grpc_args_parser.py
Normal file
@ -0,0 +1,101 @@
|
||||
from milvus import Status
|
||||
from functools import wraps
|
||||
|
||||
|
||||
def error_status(func):
|
||||
@wraps(func)
|
||||
def inner(*args, **kwargs):
|
||||
try:
|
||||
results = func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return Status(code=Status.UNEXPECTED_ERROR, message=str(e)), None
|
||||
|
||||
return Status(code=0, message="Success"), results
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
class GrpcArgsParser(object):
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_TableSchema(cls, param):
|
||||
_table_schema = {
|
||||
'table_name': param.table_name.table_name,
|
||||
'dimension': param.dimension,
|
||||
'index_file_size': param.index_file_size,
|
||||
'metric_type': param.metric_type
|
||||
}
|
||||
|
||||
return _table_schema
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_TableName(cls, param):
|
||||
return param.table_name
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_Index(cls, param):
|
||||
_index = {
|
||||
'index_type': param.index_type,
|
||||
'nlist': param.nlist
|
||||
}
|
||||
|
||||
return _index
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_IndexParam(cls, param):
|
||||
_table_name = param.table_name.table_name
|
||||
_status, _index = cls.parse_proto_Index(param.index)
|
||||
|
||||
if not _status.OK():
|
||||
raise Exception("Argument parse error")
|
||||
|
||||
return _table_name, _index
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_Command(cls, param):
|
||||
_cmd = param.cmd
|
||||
|
||||
return _cmd
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_Range(cls, param):
|
||||
_start_value = param.start_value
|
||||
_end_value = param.end_value
|
||||
|
||||
return _start_value, _end_value
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_RowRecord(cls, param):
|
||||
return list(param.vector_data)
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_SearchParam(cls, param):
|
||||
_table_name = param.table_name
|
||||
_topk = param.topk
|
||||
_nprobe = param.nprobe
|
||||
_status, _range = cls.parse_proto_Range(param.query_range_array)
|
||||
|
||||
if not _status.OK():
|
||||
raise Exception("Argument parse error")
|
||||
|
||||
_row_record = param.query_record_array
|
||||
|
||||
return _table_name, _row_record, _range, _topk
|
||||
|
||||
@classmethod
|
||||
@error_status
|
||||
def parse_proto_DeleteByRangeParam(cls, param):
|
||||
_table_name = param.table_name
|
||||
_range = param.range
|
||||
_start_value = _range.start_value
|
||||
_end_value = _range.end_value
|
||||
|
||||
return _table_name, _start_value, _end_value
|
||||
4
mishards/grpc_utils/grpc_args_wrapper.py
Normal file
4
mishards/grpc_utils/grpc_args_wrapper.py
Normal file
@ -0,0 +1,4 @@
|
||||
# class GrpcArgsWrapper(object):
|
||||
|
||||
# @classmethod
|
||||
# def proto_TableName(cls):
|
||||
14
mishards/main.py
Normal file
14
mishards/main.py
Normal file
@ -0,0 +1,14 @@
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import settings
|
||||
from mishards import connect_mgr, grpc_server as server
|
||||
|
||||
def main():
|
||||
connect_mgr.register('WOSERVER', settings.WOSERVER)
|
||||
server.run(port=settings.SERVER_PORT)
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
47
mishards/server.py
Normal file
47
mishards/server.py
Normal file
@ -0,0 +1,47 @@
|
||||
import logging
|
||||
import grpc
|
||||
import time
|
||||
from concurrent import futures
|
||||
from grpc._cython import cygrpc
|
||||
from milvus.grpc_gen.milvus_pb2_grpc import add_MilvusServiceServicer_to_server
|
||||
from service_handler import ServiceHandler
|
||||
import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Server:
|
||||
def __init__(self, conn_mgr, port=19530, max_workers=10, **kwargs):
|
||||
self.exit_flag = False
|
||||
self.port = int(port)
|
||||
self.conn_mgr = conn_mgr
|
||||
self.server_impl = grpc.server(
|
||||
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
|
||||
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
|
||||
(cygrpc.ChannelArgKey.max_receive_message_length, -1)]
|
||||
)
|
||||
|
||||
def start(self, port=None):
|
||||
add_MilvusServiceServicer_to_server(ServiceHandler(conn_mgr=self.conn_mgr), self.server_impl)
|
||||
self.server_impl.add_insecure_port("[::]:{}".format(str(port or self._port)))
|
||||
self.server_impl.start()
|
||||
|
||||
def run(self, port):
|
||||
logger.info('Milvus server start ......')
|
||||
port = port or self.port
|
||||
|
||||
self.start(port)
|
||||
logger.info('Successfully')
|
||||
logger.info('Listening on port {}'.format(port))
|
||||
|
||||
try:
|
||||
while not self.exit_flag:
|
||||
time.sleep(5)
|
||||
except KeyboardInterrupt:
|
||||
self.stop()
|
||||
|
||||
def stop(self):
|
||||
logger.info('Server is shuting down ......')
|
||||
self.exit_flag = True
|
||||
self.server.stop(0)
|
||||
logger.info('Server is closed')
|
||||
327
mishards/service_handler.py
Normal file
327
mishards/service_handler.py
Normal file
@ -0,0 +1,327 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
|
||||
|
||||
from grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
def __init__(self, conn_mgr, *args, **kwargs):
|
||||
self.conn_mgr = conn_mgr
|
||||
self.table_meta = {}
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
conn = self.conn_mgr.conn('WOSERVER')
|
||||
if conn:
|
||||
conn.on_connect()
|
||||
return conn.conn
|
||||
|
||||
def CreateTable(self, request, context):
|
||||
_status, _table_schema = Parser.parse_proto_TableSchema(request)
|
||||
|
||||
if not _status.OK():
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
logger.info('CreateTable {}'.format(_table_schema['table_name']))
|
||||
|
||||
_status = self.connection.create_table(_table_schema)
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def HasTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
if not _status.OK():
|
||||
return milvus_pb2.BoolReply(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
bool_reply=False
|
||||
)
|
||||
|
||||
logger.info('HasTable {}'.format(_table_name))
|
||||
|
||||
_bool = self.connection.has_table(_table_name)
|
||||
|
||||
return milvus_pb2.BoolReply(
|
||||
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="OK"),
|
||||
bool_reply=_bool
|
||||
)
|
||||
|
||||
def DropTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
if not _status.OK():
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
logger.info('DropTable {}'.format(_table_name))
|
||||
|
||||
_status = self.connection.delete_table(_table_name)
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def CreateIndex(self, request, context):
|
||||
_status, unpacks = Parser.parse_proto_IndexParam(request)
|
||||
|
||||
if not _status.OK():
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
_table_name, _index = unpacks
|
||||
|
||||
logger.info('CreateIndex {}'.format(_table_name))
|
||||
|
||||
# TODO: interface create_table incompleted
|
||||
_status = self.connection.create_index(_table_name, _index)
|
||||
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def Insert(self, request, context):
|
||||
logger.info('Insert')
|
||||
# TODO: Ths SDK interface add_vectors() could update, add a key 'row_id_array'
|
||||
_status, _ids = self.connection.add_vectors(None, None, insert_param=request)
|
||||
return milvus_pb2.VectorIds(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
vector_id_array=_ids
|
||||
)
|
||||
|
||||
def Search(self, request, context):
|
||||
|
||||
try:
|
||||
table_name = request.table_name
|
||||
|
||||
topk = request.topk
|
||||
nprobe = request.nprobe
|
||||
|
||||
logger.info('Search {}: topk={} nprobe={}'.format(table_name, topk, nprobe))
|
||||
|
||||
if nprobe > 2048 or nprobe <= 0:
|
||||
raise exceptions.GRPCInvlidArgument('Invalid nprobe: {}'.format(nprobe))
|
||||
|
||||
table_meta = self.table_meta.get(table_name, None)
|
||||
if not table_meta:
|
||||
status, info = self.connection.describe_table(table_name)
|
||||
if not status.OK():
|
||||
raise TableNotFoundException(table_name)
|
||||
|
||||
self.table_meta[table_name] = info
|
||||
table_meta = info
|
||||
|
||||
start = time.time()
|
||||
|
||||
query_record_array = []
|
||||
|
||||
for query_record in request.query_record_array:
|
||||
query_record_array.append(list(query_record.vector_data))
|
||||
|
||||
query_range_array = []
|
||||
for query_range in request.query_range_array:
|
||||
query_range_array.append(
|
||||
Range(query_range.start_value, query_range.end_value))
|
||||
except (TableNotFoundException, exceptions.GRPCInvlidArgument) as exc:
|
||||
return milvus_pb2.TopKQueryResultList(
|
||||
status=status_pb2.Status(error_code=exc.code, reason=exc.message)
|
||||
)
|
||||
except Exception as e:
|
||||
return milvus_pb2.TopKQueryResultList(
|
||||
status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e))
|
||||
)
|
||||
|
||||
try:
|
||||
results = workflow.query_vectors(table_name, table_meta, query_record_array, topk,
|
||||
nprobe, query_range_array)
|
||||
except (exceptions.GRPCQueryInvalidRangeException, TableNotFoundException) as exc:
|
||||
return milvus_pb2.TopKQueryResultList(
|
||||
status=status_pb2.Status(error_code=exc.code, reason=exc.message)
|
||||
)
|
||||
except exceptions.ServiceNotFoundException as exc:
|
||||
return milvus_pb2.TopKQueryResultList(
|
||||
status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=exc.message)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
results = workflow.query_vectors(table_name, table_meta, query_record_array,
|
||||
topk, nprobe, query_range_array)
|
||||
|
||||
now = time.time()
|
||||
logger.info('SearchVector Ends @{}'.format(now))
|
||||
logger.info('SearchVector takes: {}'.format(now - start))
|
||||
|
||||
topk_result_list = milvus_pb2.TopKQueryResultList(
|
||||
status=status_pb2.Status(error_code=status_pb2.SUCCESS, reason="Success"),
|
||||
topk_query_result=results
|
||||
)
|
||||
return topk_result_list
|
||||
|
||||
def SearchInFiles(self, request, context):
|
||||
try:
|
||||
file_id_array = list(request.file_id_array)
|
||||
search_param = request.search_param
|
||||
table_name = search_param.table_name
|
||||
topk = search_param.topk
|
||||
nprobe = search_param.nprobe
|
||||
|
||||
query_record_array = []
|
||||
|
||||
for query_record in search_param.query_record_array:
|
||||
query_record_array.append(list(query_record))
|
||||
|
||||
query_range_array = []
|
||||
for query_range in search_param.query_range_array:
|
||||
query_range_array.append("")
|
||||
except Exception as e:
|
||||
milvus_pb2.TopKQueryResultList(
|
||||
status=status_pb2.Status(error_code=status_pb2.UNEXPECTED_ERROR, reason=str(e)),
|
||||
)
|
||||
|
||||
res = search_vector_in_files.delay(table_name=table_name,
|
||||
file_id_array=file_id_array,
|
||||
query_record_array=query_record_array,
|
||||
query_range_array=query_range_array,
|
||||
topk=topk,
|
||||
nprobe=nprobe)
|
||||
status, result = res.get(timeout=1)
|
||||
|
||||
if not status.OK():
|
||||
raise ThriftException(code=status.code, reason=status.message)
|
||||
res = TopKQueryResult()
|
||||
for top_k_query_results in result:
|
||||
res.query_result_arrays.append([QueryResult(id=qr.id, distance=qr.distance)
|
||||
for qr in top_k_query_results])
|
||||
return res
|
||||
|
||||
def DescribeTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
if not _status.OK():
|
||||
table_name = milvus_pb2.TableName(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
)
|
||||
return milvus_pb2.TableSchema(
|
||||
table_name=table_name
|
||||
)
|
||||
|
||||
logger.info('DescribeTable {}'.format(_table_name))
|
||||
_status, _table = self.connection.describe_table(_table_name)
|
||||
|
||||
if _status.OK():
|
||||
_grpc_table_name = milvus_pb2.TableName(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_name=_table.table_name
|
||||
)
|
||||
|
||||
return milvus_pb2.TableSchema(
|
||||
table_name=_grpc_table_name,
|
||||
index_file_size=_table.index_file_size,
|
||||
dimension=_table.dimension,
|
||||
metric_type=_table.metric_type
|
||||
)
|
||||
|
||||
return milvus_pb2.TableSchema(
|
||||
table_name=milvus_pb2.TableName(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
)
|
||||
)
|
||||
|
||||
def CountTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
if not _status.OK():
|
||||
status = status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
return milvus_pb2.TableRowCount(
|
||||
status=status
|
||||
)
|
||||
|
||||
logger.info('CountTable {}'.format(_table_name))
|
||||
|
||||
_status, _count = self.connection.get_table_row_count(_table_name)
|
||||
|
||||
return milvus_pb2.TableRowCount(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_row_count=_count if isinstance(_count, int) else -1)
|
||||
|
||||
def Cmd(self, request, context):
|
||||
_status, _cmd = Parser.parse_proto_Command(request)
|
||||
logger.info('Cmd: {}'.format(_cmd))
|
||||
|
||||
if not _status.OK():
|
||||
return milvus_pb2.StringReply(
|
||||
status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
)
|
||||
|
||||
if _cmd == 'version':
|
||||
_status, _reply = self.connection.server_version()
|
||||
else:
|
||||
_status, _reply = self.connection.server_status()
|
||||
|
||||
return milvus_pb2.StringReply(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
string_reply=_reply
|
||||
)
|
||||
|
||||
def ShowTables(self, request, context):
|
||||
logger.info('ShowTables')
|
||||
_status, _results = self.connection.show_tables()
|
||||
|
||||
if not _status.OK():
|
||||
_results = []
|
||||
|
||||
for _result in _results:
|
||||
yield milvus_pb2.TableName(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_name=_result
|
||||
)
|
||||
|
||||
def DeleteByRange(self, request, context):
|
||||
_status, unpacks = \
|
||||
Parser.parse_proto_DeleteByRangeParam(request)
|
||||
|
||||
if not _status.OK():
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
_table_name, _start_date, _end_date = unpacks
|
||||
|
||||
logger.info('DeleteByRange {}: {} {}'.format(_table_name, _start_date, _end_date))
|
||||
_status = self.connection.delete_vectors_by_range(_table_name, _start_date, _end_date)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def PreloadTable(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
if not _status.OK():
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
logger.info('PreloadTable {}'.format(_table_name))
|
||||
_status = self.connection.preload_table(_table_name)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
def DescribeIndex(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
if not _status.OK():
|
||||
return milvus_pb2.IndexParam(
|
||||
table_name=milvus_pb2.TableName(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
)
|
||||
)
|
||||
|
||||
logger.info('DescribeIndex {}'.format(_table_name))
|
||||
_status, _index_param = self.connection.describe_index(_table_name)
|
||||
|
||||
_index = milvus_pb2.Index(index_type=_index_param._index_type, nlist=_index_param._nlist)
|
||||
_tablename = milvus_pb2.TableName(
|
||||
status=status_pb2.Status(error_code=_status.code, reason=_status.message),
|
||||
table_name=_table_name)
|
||||
|
||||
return milvus_pb2.IndexParam(table_name=_tablename, index=_index)
|
||||
|
||||
def DropIndex(self, request, context):
|
||||
_status, _table_name = Parser.parse_proto_TableName(request)
|
||||
|
||||
if not _status.OK():
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
|
||||
logger.info('DropIndex {}'.format(_table_name))
|
||||
_status = self.connection.drop_index(_table_name)
|
||||
return status_pb2.Status(error_code=_status.code, reason=_status.message)
|
||||
@ -22,6 +22,8 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
|
||||
TIMEOUT = env.int('TIMEOUT', 60)
|
||||
MAX_RETRY = env.int('MAX_RETRY', 3)
|
||||
|
||||
SERVER_PORT = env.int('SERVER_PORT', 19530)
|
||||
WOSERVER = env.str('WOSERVER')
|
||||
|
||||
if __name__ == '__main__':
|
||||
import logging
|
||||
@ -1,11 +0,0 @@
|
||||
import logging
|
||||
|
||||
import grpco
|
||||
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
|
||||
def __init__(self, connections, *args, **kwargs):
|
||||
self.connections = self.connections
|
||||
Loading…
x
Reference in New Issue
Block a user