update for models

This commit is contained in:
peng.xu 2019-09-18 16:59:04 +08:00
parent 099317edee
commit 5646362157
9 changed files with 262 additions and 19 deletions

View File

@ -1,5 +1,6 @@
import fire
from mishards import db
from sqlalchemy import and_
class DBHandler:
@classmethod
@ -10,5 +11,17 @@ class DBHandler:
def drop_all(cls):
db.drop_all()
@classmethod
def fun(cls, tid):
    # Debug CLI helper: fetch and print the live (non-soft-deleted) Tables
    # row whose table_id equals `tid`. Exposed through fire.Fire(DBHandler).
    # Imported lazily so the factories module (and its factory_boy/faker
    # dependencies) is only loaded when this command is actually invoked.
    from mishards.factories import TablesFactory, TableFilesFactory, Tables
    f = db.Session.query(Tables).filter(and_(
        Tables.table_id==tid,
        Tables.state!=Tables.TO_DELETE)
    ).first()
    # Prints the matched row (or None when no live row exists).
    print(f)
    # f1 = TableFilesFactory()
if __name__ == '__main__':
fire.Fire(DBHandler)

View File

@ -2,7 +2,7 @@ from mishards import settings
from mishards.db_base import DB
db = DB()
db.init_db(uri=settings.SQLALCHEMY_DATABASE_URI)
db.init_db(uri=settings.SQLALCHEMY_DATABASE_URI, echo=settings.SQL_ECHO)
from mishards.connections import ConnectionMgr
connect_mgr = ConnectionMgr()

View File

@ -1,15 +1,20 @@
import logging
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
logger = logging.getLogger(__name__)
class DB:
Model = declarative_base()
def __init__(self, uri=None):
uri and self.init_db(uri)
def __init__(self, uri=None, echo=False):
self.echo = echo
uri and self.init_db(uri, echo)
def init_db(self, uri):
def init_db(self, uri, echo=False):
self.engine = create_engine(uri, pool_size=100, pool_recycle=5, pool_timeout=30,
pool_pre_ping=True,
echo=echo,
max_overflow=0)
self.uri = uri
session = sessionmaker()

49
mishards/factories.py Normal file
View File

@ -0,0 +1,49 @@
import time
import datetime
import random
import factory
from factory.alchemy import SQLAlchemyModelFactory
from faker import Faker
from faker.providers import BaseProvider
from mishards import db
from mishards.models import Tables, TableFiles
class FakerProvider(BaseProvider):
    # Custom faker provider registered below and used by the model factories
    # (e.g. `factory.Faker('this_date')`).
    def this_date(self):
        # Pack today's date into a single integer:
        # (year-1900)*10000 + (month-1)*100 + day.
        # NOTE(review): the -1900 / month-1 offsets look like C `struct tm`
        # fields (tm_year, tm_mon) — confirm this matches the encoding the
        # server uses for the TableFiles `date` column.
        t = datetime.datetime.today()
        return (t.year - 1900) * 10000 + (t.month-1)*100 + t.day
factory.Faker.add_provider(FakerProvider)
class TablesFactory(SQLAlchemyModelFactory):
    """Factory producing `Tables` rows, committed via the shared db.Session."""

    class Meta:
        model = Tables
        sqlalchemy_session = db.Session
        sqlalchemy_session_persistence = 'commit'

    id = factory.Faker('random_number', digits=16, fix_len=True)
    table_id = factory.Faker('uuid4')
    state = factory.Faker('random_element', elements=(0, 1, 2, 3))
    dimension = factory.Faker('random_element', elements=(256, 512))
    # LazyFunction re-evaluates per created instance; the previous
    # `int(time.time())` was evaluated once at import time, so every row
    # shared the module-load timestamp.
    created_on = factory.LazyFunction(lambda: int(time.time()))
    index_file_size = 0
    engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
    metric_type = factory.Faker('random_element', elements=(0, 1))
    nlist = 16384
class TableFilesFactory(SQLAlchemyModelFactory):
    """Factory producing `TableFiles` rows; `table` builds a parent row
    through TablesFactory."""

    class Meta:
        model = TableFiles
        sqlalchemy_session = db.Session
        sqlalchemy_session_persistence = 'commit'

    id = factory.Faker('random_number', digits=16, fix_len=True)
    table = factory.SubFactory(TablesFactory)
    engine_type = factory.Faker('random_element', elements=(0, 1, 2, 3))
    file_id = factory.Faker('uuid4')
    file_type = factory.Faker('random_element', elements=(0, 1, 2, 3, 4))
    file_size = factory.Faker('random_number')
    # LazyFunction evaluates per instance; a bare `int(time.time())` would be
    # frozen at import time and shared by every generated row.
    updated_time = factory.LazyFunction(lambda: int(time.time()))
    created_on = factory.LazyFunction(lambda: int(time.time()))
    date = factory.Faker('this_date')

150
mishards/hash_ring.py Normal file
View File

@ -0,0 +1,150 @@
import math
import sys
from bisect import bisect

# md5 lives in hashlib from Python 2.5 onward; keep the fallback to the old
# `md5` module for ancient interpreters.
if sys.version_info >= (2, 5):
    import hashlib
    md5_constructor = hashlib.md5
else:
    import md5
    md5_constructor = md5.new


class HashRing(object):
    """Consistent-hash ring mapping string keys onto a set of nodes."""

    def __init__(self, nodes=None, weights=None):
        """`nodes` is a list of objects that have a proper __str__ representation.
        `weights` is dictionary that sets weights to the nodes. The default
        weight is that all nodes are equal.
        """
        self.ring = dict()
        self._sorted_keys = []

        self.nodes = nodes

        if not weights:
            weights = {}
        self.weights = weights

        self._generate_circle()

    def _generate_circle(self):
        """Generates the circle."""
        total_weight = 0
        for node in self.nodes:
            total_weight += self.weights.get(node, 1)

        for node in self.nodes:
            weight = 1
            if node in self.weights:
                weight = self.weights.get(node)

            # Virtual points proportional to weight; each point expands to 3
            # ring keys below, so an equal-weight node gets 40*3 = 120 keys.
            factor = math.floor((40 * len(self.nodes) * weight) / total_weight)

            for j in range(0, int(factor)):
                b_key = self._hash_digest('%s-%s' % (node, j))

                for i in range(0, 3):
                    key = self._hash_val(b_key, lambda x: x + i * 4)
                    self.ring[key] = node
                    self._sorted_keys.append(key)

        self._sorted_keys.sort()

    def get_node(self, string_key):
        """Given a string key a corresponding node in the hash ring is returned.

        If the hash ring is empty, `None` is returned.
        """
        pos = self.get_node_pos(string_key)
        if pos is None:
            return None
        return self.ring[self._sorted_keys[pos]]

    def get_node_pos(self, string_key):
        """Given a string key, return the position of its node in the ring.

        If the hash ring is empty, `None` is returned.
        """
        if not self.ring:
            return None

        key = self.gen_key(string_key)

        nodes = self._sorted_keys
        pos = bisect(nodes, key)

        # Past the last key: wrap around to the start of the ring.
        if pos == len(nodes):
            return 0
        else:
            return pos

    def iterate_nodes(self, string_key, distinct=True):
        """Given a string key it returns the nodes as a generator that can hold the key.

        The generator iterates one time through the ring
        starting at the correct position.

        if `distinct` is set, then the nodes returned will be unique,
        i.e. no virtual copies will be returned.
        """
        if not self.ring:
            # Keep the historical (None, None) sentinel, then STOP: the
            # original fell through and indexed _sorted_keys with pos=None,
            # raising TypeError on an empty ring.
            yield None, None
            return

        returned_values = set()

        def distinct_filter(value):
            # Honor `distinct`: when unset, every virtual copy is yielded
            # (the original ignored the flag and always de-duplicated).
            if not distinct:
                return value
            if str(value) not in returned_values:
                returned_values.add(str(value))
                return value

        pos = self.get_node_pos(string_key)
        for key in self._sorted_keys[pos:]:
            val = distinct_filter(self.ring[key])
            if val:
                yield val

        for i, key in enumerate(self._sorted_keys):
            if i < pos:
                val = distinct_filter(self.ring[key])
                if val:
                    yield val

    def gen_key(self, key):
        """Given a string key it returns a long value,
        this long value represents a place on the hash ring.

        md5 is currently used because it mixes well.
        """
        b_key = self._hash_digest(key)
        return self._hash_val(b_key, lambda x: x)

    def _hash_val(self, b_key, entry_fn):
        # Fold 4 digest bytes into a 32-bit integer (little-endian order).
        return ((b_key[entry_fn(3)] << 24)
                | (b_key[entry_fn(2)] << 16)
                | (b_key[entry_fn(1)] << 8)
                | b_key[entry_fn(0)])

    def _hash_digest(self, key):
        m = md5_constructor()
        key = key.encode()
        m.update(key)
        return m.digest()
if __name__ == '__main__':
    # Ad-hoc demo: spread 100 keys across four servers and print the grouping.
    from collections import defaultdict

    hosts = ['192.168.0.246:11212',
             '192.168.0.247:11212',
             '192.168.0.248:11212',
             '192.168.0.249:11212']
    hash_ring = HashRing(hosts)

    distribution = defaultdict(list)
    for key in ('{}'.format(idx) for idx in range(100)):
        distribution[hash_ring.get_node(key)].append(key)

    for host, assigned in distribution.items():
        print(host, assigned)

View File

@ -32,8 +32,8 @@ class TableFiles(db.Model):
date = Column(Integer)
table = relationship(
'Table',
primaryjoin='and_(foreign(TableFile.table_id) == Table.table_id)',
'Tables',
primaryjoin='and_(foreign(TableFiles.table_id) == Tables.table_id)',
backref=backref('files', uselist=True, lazy='dynamic')
)
@ -57,15 +57,15 @@ class Tables(db.Model):
def files_to_search(self, date_range=None):
cond = or_(
TableFile.file_type==TableFile.FILE_TYPE_RAW,
TableFile.file_type==TableFile.FILE_TYPE_TO_INDEX,
TableFile.file_type==TableFile.FILE_TYPE_INDEX,
TableFiles.file_type==TableFiles.FILE_TYPE_RAW,
TableFiles.file_type==TableFiles.FILE_TYPE_TO_INDEX,
TableFiles.file_type==TableFiles.FILE_TYPE_INDEX,
)
if date_range:
cond = and_(
cond,
or_(
and_(TableFile.date>=d[0], TableFile.date<d[1]) for d in date_range
and_(TableFiles.date>=d[0], TableFiles.date<d[1]) for d in date_range
)
)

View File

@ -250,8 +250,8 @@ class ServiceFounder(object):
self.listener.daemon = True
self.listener.start()
self.event_handler.start()
while self.listener.at_start_up:
time.sleep(1)
# while self.listener.at_start_up:
# time.sleep(1)
self.pod_heartbeater.start()

View File

@ -4,14 +4,17 @@ import datetime
from contextlib import contextmanager
from collections import defaultdict
from sqlalchemy import and_
from concurrent.futures import ThreadPoolExecutor
from milvus.grpc_gen import milvus_pb2, milvus_pb2_grpc, status_pb2
from milvus.grpc_gen.milvus_pb2 import TopKQueryResult
from milvus.client import types
from mishards import (settings, exceptions)
from mishards import (db, settings, exceptions)
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards.models import Tables, TableFiles
from mishards.hash_ring import HashRing
logger = logging.getLogger(__name__)
@ -53,12 +56,34 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
return self._format_date(start, end)
def _get_routing_file_ids(self, table_id, range_array):
    """Map each searchable file of `table_id` to a backend server.

    Files are distributed over the currently-registered servers with a
    consistent-hash ring keyed by file id, so a given file keeps routing
    to the same server while the server set is stable.

    Returns {server_name: {'table_id': ..., 'file_ids': [str, ...]}}.
    Raises exceptions.TableNotFoundError when no live table matches.
    """
    table = db.Session.query(Tables).filter(and_(
        Tables.table_id==table_id,
        Tables.state!=Tables.TO_DELETE
    )).first()

    if not table:
        raise exceptions.TableNotFoundError(table_id)
    # Demoted from logger.error: a successful lookup is not an error path.
    logger.debug('Routing table: {}'.format(table))

    files = table.files_to_search(range_array)
    servers = self.conn_mgr.conn_names
    logger.info('Available servers: {}'.format(servers))

    ring = HashRing(servers)
    routing = {}
    for f in files:
        target_host = ring.get_node(str(f.id))
        # setdefault creates the per-server bucket on first hit.
        entry = routing.setdefault(target_host, {
            'table_id': table_id,
            'file_ids': []
        })
        entry['file_ids'].append(str(f.id))
    return routing
def _do_merge(self, files_n_topk_results, topk, reverse=False):
if not files_n_topk_results:
@ -88,7 +113,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
def _do_query(self, table_id, table_meta, vectors, topk, nprobe, range_array=None, **kwargs):
range_array = [self._range_to_date(r) for r in range_array] if range_array else None
routing = self._get_routing_file_ids(table_id, range_array)
logger.debug(routing)
logger.info('Routing: {}'.format(routing))
rs = []
all_topk_results = []

View File

@ -19,6 +19,7 @@ from mishards.utils.logger_helper import config
config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
SQLALCHEMY_DATABASE_URI = env.str('SQLALCHEMY_DATABASE_URI')
SQL_ECHO = env.bool('SQL_ECHO', False)
TIMEOUT = env.int('TIMEOUT', 60)
MAX_RETRY = env.int('MAX_RETRY', 3)