未验证 提交 5f2f8bdc 编写于 作者: X XuPeng-SH 提交者: GitHub

[skip ci] (shards) Upgrade Mishards for #1569 (#1570)

* [skip ci](shards): export MAX_WORKERS as configurable parameter
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): skip mishards .env git info
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): support more robust static discovery host configuration
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): update static provider that terminate server if connection to downstream server error during startup
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): add topology.py
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): add connection pool
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): add topology test
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): refactor using topo
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): refactor static discovery using topo
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): refactor kubernetes discovery using topo
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): add more test for connection pool
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): export 19541 and 19542 for all_in_one demo
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): check version on new connection
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): mock connections
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>

* [skip ci](shards): update tests
Signed-off-by: Npeng.xu <peng.xu@zilliz.com>
上级 4088f5e9
......@@ -31,3 +31,4 @@ cov_html/
# temp
shards/all_in_one_with_mysql/metadata/
shards/mishards/.env
......@@ -4,6 +4,8 @@ services:
runtime: nvidia
restart: always
image: milvusdb/milvus:0.6.0-gpu-d120719-2b40dd
ports:
- "0.0.0.0:19540:19530"
volumes:
- /tmp/milvus/db:/var/lib/milvus/db
- ./wr_server.yml:/opt/milvus/conf/server_config.yaml
......@@ -12,6 +14,8 @@ services:
runtime: nvidia
restart: always
image: milvusdb/milvus:0.6.0-gpu-d120719-2b40dd
ports:
- "0.0.0.0:19541:19530"
volumes:
- /tmp/milvus/db:/var/lib/milvus/db
- ./ro_server.yml:/opt/milvus/conf/server_config.yaml
......
......@@ -2,8 +2,10 @@ import os
import logging
import pytest
import grpc
import mock
import tempfile
import shutil
import time
from mishards import settings, db, create_app
logger = logging.getLogger(__name__)
......@@ -18,6 +20,9 @@ settings.TestingConfig.SQLALCHEMY_DATABASE_URI = 'sqlite:///{}?check_same_thread
@pytest.fixture
def app(request):
from mishards.connections import ConnectionGroup
ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,))
time.sleep(0.1)
app = create_app(settings.TestingConfig)
db.drop_all()
db.create_all()
......
......@@ -13,10 +13,10 @@ class DiscoveryFactory(BaseMixin):
super().__init__(searchpath=searchpath, package_name=PLUGIN_PACKAGE_NAME)
def _create(self, plugin_class, **kwargs):
    """Instantiate a discovery plugin.

    `readonly_topo` must be supplied in kwargs; it is popped and passed
    explicitly to the plugin's Create hook.
    """
    readonly_topo = kwargs.pop('readonly_topo', None)
    if not readonly_topo:
        raise RuntimeError('Please pass readonly_topo to create discovery!')
    plugin_config = DiscoveryConfig.Create()
    return plugin_class.Create(plugin_config=plugin_config,
                               readonly_topo=readonly_topo,
                               **kwargs)
......@@ -181,7 +181,7 @@ class EventHandler(threading.Thread):
self.mgr.delete_pod(name=event['pod'])
def on_pod_heartbeat(self, event, **kwargs):
names = self.mgr.conn_mgr.conn_names
names = self.mgr.readonly_topo.group_names
running_names = set()
for each_event in event['events']:
......@@ -195,7 +195,7 @@ class EventHandler(threading.Thread):
for name in to_delete:
self.mgr.delete_pod(name)
logger.info(self.mgr.conn_mgr.conn_names)
logger.info(self.mgr.readonly_topo.group_names)
def handle_event(self, event):
if event['eType'] == EventType.PodHeartBeat:
......@@ -237,7 +237,7 @@ class KubernetesProviderSettings:
class KubernetesProvider(object):
name = 'kubernetes'
def __init__(self, plugin_config, conn_mgr, **kwargs):
def __init__(self, plugin_config, readonly_topo, **kwargs):
self.namespace = plugin_config.DISCOVERY_KUBERNETES_NAMESPACE
self.pod_patt = plugin_config.DISCOVERY_KUBERNETES_POD_PATT
self.label_selector = plugin_config.DISCOVERY_KUBERNETES_LABEL_SELECTOR
......@@ -250,7 +250,7 @@ class KubernetesProvider(object):
self.kwargs = kwargs
self.queue = queue.Queue()
self.conn_mgr = conn_mgr
self.readonly_topo = readonly_topo
if not self.namespace:
self.namespace = open(incluster_namespace_path).read()
......@@ -281,10 +281,24 @@ class KubernetesProvider(object):
**kwargs)
def add_pod(self, name, ip):
    """Register pod ``name`` at ``ip`` as a one-member readonly group.

    Returns True when the pod was registered (or already known), False
    when connecting to it failed.
    """
    ok = True
    status = StatusType.OK
    # Build the uri before the try-block so it is always bound for logging.
    uri = 'tcp://{}:{}'.format(ip, self.port)
    try:
        status, group = self.readonly_topo.create(name=name)
        if status == StatusType.OK:
            status, pool = group.create(name=name, uri=uri)
    except ConnectionConnectError:
        ok = False
        # BUG FIX: the original logged the undefined name `addr`, which
        # raised NameError inside the exception handler; log the uri.
        logger.error('Connection error to: {}'.format(uri))
    if ok and status == StatusType.OK:
        logger.info('KubernetesProvider Add Group \"{}\" Of 1 Address: {}'.format(name, uri))
    return ok
def delete_pod(self, name):
    """Unregister pod ``name`` by dropping its single-member group."""
    # The removed group is not needed; drop the unused local binding.
    self.readonly_topo.delete_group(name)
    return True
def start(self):
self.listener.daemon = True
......@@ -299,8 +313,8 @@ class KubernetesProvider(object):
self.event_handler.stop()
@classmethod
def Create(cls, readonly_topo, plugin_config, **kwargs):
    """Factory hook used by DiscoveryFactory."""
    # BUG FIX: __init__ declares `plugin_config`, not `config`; passing
    # config= sent it into **kwargs and left plugin_config unbound,
    # raising TypeError on every plugin creation.
    discovery = cls(plugin_config=plugin_config, readonly_topo=readonly_topo, **kwargs)
    return discovery
......
......@@ -6,37 +6,72 @@ if __name__ == '__main__':
import logging
import socket
from environs import Env
from mishards.exceptions import ConnectionConnectError
from mishards.topology import StatusType
logger = logging.getLogger(__name__)
env = Env()
# Separator between host and optional port in a static host entry.
DELIMITER = ':'


def parse_host(addr):
    """Split ``addr`` on DELIMITER: 'host' -> ['host'], 'host:port' -> ['host', 'port']."""
    return addr.split(DELIMITER)


def resolve_address(addr, default_port):
    """Resolve ``addr`` ('host' or 'host:port') to an 'ip:port' string.

    The host part is resolved via DNS; ``default_port`` is used when the
    entry carries no explicit port.

    Raises:
        ValueError: if ``addr`` is not of the form 'host' or 'host:port'.
    """
    addr_arr = parse_host(addr)
    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, which would let malformed entries slip through.
    if not 1 <= len(addr_arr) <= 2:
        raise ValueError('Invalid Addr: {}'.format(addr))
    port = addr_arr[1] if len(addr_arr) == 2 else default_port
    return '{}:{}'.format(socket.gethostbyname(addr_arr[0]), port)
class StaticDiscovery(object):
    """Discovery plugin fed by a fixed DISCOVERY_STATIC_HOSTS list."""
    name = 'static'

    def __init__(self, config, readonly_topo, **kwargs):
        self.readonly_topo = readonly_topo
        configured = env.list('DISCOVERY_STATIC_HOSTS', [])
        self.port = env.int('DISCOVERY_STATIC_PORT', 19530)
        # Resolve each entry to 'ip:port' up front; an entry may carry its
        # own port, otherwise DISCOVERY_STATIC_PORT is used.
        self.hosts = [resolve_address(entry, self.port) for entry in configured]
def start(self):
    """Register every configured host; a False return aborts server startup."""
    ok = True
    for host in self.hosts:
        ok &= self.add_pod(host, host)
        if not ok:
            break
    if ok and not self.hosts:
        logger.error('No address is specified')
        ok = False
    return ok

def stop(self):
    """Unregister all configured hosts."""
    for host in self.hosts:
        self.delete_pod(host)
def add_pod(self, name, ip):
self.conn_mgr.register(name, 'tcp://{}:{}'.format(ip, self.port))
def add_pod(self, name, addr):
    """Register one static server as its own single-member topology group.

    Returns True when the address was added (or already present), False
    on connection or registration failure.
    """
    succeeded = True
    status = StatusType.OK
    try:
        uri = 'tcp://{}'.format(addr)
        status, group = self.readonly_topo.create(name=name)
        if status == StatusType.OK:
            status, pool = group.create(name=name, uri=uri)
            # DUPLICATED pools are tolerated; anything else is a failure.
            succeeded = status in (StatusType.OK, StatusType.DUPLICATED)
    except ConnectionConnectError:
        logger.error('Connection error to: {}'.format(addr))
        succeeded = False
    if succeeded and status == StatusType.OK:
        logger.info('StaticDiscovery Add Static Group \"{}\" Of 1 Address: {}'.format(name, addr))
    return succeeded
def delete_pod(self, name):
    """Unregister host ``name`` by dropping its topology group."""
    # The removed group is not needed; drop the unused local binding.
    self.readonly_topo.delete_group(name)
    return True
@classmethod
def Create(cls, readonly_topo, plugin_config, **kwargs):
    """Factory hook used by DiscoveryFactory."""
    return cls(config=plugin_config, readonly_topo=readonly_topo, **kwargs)
......
......@@ -3,6 +3,7 @@ DEBUG=True
WOSERVER=tcp://127.0.0.1:19530
SERVER_PORT=19535
SERVER_TEST_PORT=19888
MAX_WORKERS=50
#SQLALCHEMY_DATABASE_URI=mysql+pymysql://root:root@127.0.0.1:3306/milvus?charset=utf8mb4
SQLALCHEMY_DATABASE_URI=sqlite:////tmp/milvus/db/meta.sqlite?check_same_thread=False
......
......@@ -15,12 +15,14 @@ def create_app(testing_config=None):
pool_recycle=config.SQL_POOL_RECYCLE, pool_timeout=config.SQL_POOL_TIMEOUT,
pool_pre_ping=config.SQL_POOL_PRE_PING, max_overflow=config.SQL_MAX_OVERFLOW)
from mishards.connections import ConnectionMgr
connect_mgr = ConnectionMgr()
from mishards.connections import ConnectionMgr, ConnectionTopology
readonly_topo = ConnectionTopology()
writable_topo = ConnectionTopology()
from discovery.factory import DiscoveryFactory
discover = DiscoveryFactory(config.DISCOVERY_PLUGIN_PATH).create(config.DISCOVERY_CLASS_NAME,
conn_mgr=connect_mgr)
readonly_topo=readonly_topo)
from mishards.grpc_utils import GrpcSpanDecorator
from tracer.factory import TracerFactory
......@@ -30,12 +32,15 @@ def create_app(testing_config=None):
from mishards.router.factory import RouterFactory
router = RouterFactory(config.ROUTER_PLUGIN_PATH).create(config.ROUTER_CLASS_NAME,
conn_mgr=connect_mgr)
readonly_topo=readonly_topo,
writable_topo=writable_topo)
grpc_server.init_app(conn_mgr=connect_mgr,
grpc_server.init_app(writable_topo=writable_topo,
readonly_topo=readonly_topo,
tracer=tracer,
router=router,
discover=discover)
discover=discover,
max_workers=settings.MAX_WORKERS)
from mishards import exception_handlers
......
import logging
import threading
import enum
from functools import wraps
from milvus import Milvus
from milvus.client.hooks import BaseSearchHook
from mishards import (settings, exceptions)
from mishards import (settings, exceptions, topology)
from utils import singleton
logger = logging.getLogger(__name__)
......@@ -81,6 +82,140 @@ class Connection:
raise e
return inner
def __str__(self):
return '<Connection: {}:{}>'.format(self.name, id(self))
def __repr__(self):
return self.__str__()
class ProxyMixin:
    """Forward unknown attribute lookups to ``self.connection``."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; fall through
        # to the wrapped connection unless it is unset.
        found = self.__dict__.get(name, None)
        if not found and self.connection:
            return getattr(self.connection, name)
        return found


class ScopedConnection(ProxyMixin):
    """Scope-bound wrapper that gives a pooled connection back on destruction."""

    def __init__(self, pool, connection):
        self.pool = pool
        self.connection = connection

    def __del__(self):
        # Return the connection when the wrapper goes out of scope.
        self.release()

    def __str__(self):
        return self.connection.__str__()

    def release(self):
        """Hand the connection back to its pool; safe to call repeatedly."""
        if self.pool and self.connection:
            self.pool.release(self.connection)
            self.connection = None
            self.pool = None
class ConnectionPool(topology.TopoObject):
    """Thread-safe pool of ``Connection`` objects for one server URI.

    ``capacity`` < 0 means unbounded; otherwise at most ``capacity``
    connections (pending + active) exist at any time.
    """

    def __init__(self, name, uri, max_retry=1, capacity=-1, **kwargs):
        super().__init__(name)
        self.capacity = capacity
        self.pending_pool = set()       # idle connections, ready to hand out
        self.active_pool = set()        # connections currently checked out
        self.connection_ownership = {}  # NOTE(review): never used below — dead? confirm
        self.uri = uri
        self.max_retry = max_retry
        self.kwargs = kwargs            # extra args forwarded to Connection()
        self.cv = threading.Condition()

    def __len__(self):
        """Total number of connections owned by the pool."""
        return len(self.pending_pool) + len(self.active_pool)

    @property
    def active_num(self):
        """Number of connections currently checked out."""
        return len(self.active_pool)

    def _is_full(self):
        # An unbounded pool (capacity < 0) is never full.
        if self.capacity < 0:
            return False
        return len(self) >= self.capacity

    def fetch(self, timeout=1):
        """Check out a connection wrapped in a ScopedConnection.

        Waits at most one round of ``timeout`` seconds when the pool is
        full, then returns None. NOTE(review): a waiter returns None even
        if a slot was freed during its wait — the single-round semantics
        appear intentional (tests rely on exactly `capacity` successes);
        confirm before changing.
        """
        with self.cv:
            timeout_times = 0
            while (len(self.pending_pool) == 0 and self._is_full() and timeout_times < 1):
                # `notify_all` replaces the deprecated `notifyAll` alias.
                self.cv.notify_all()
                self.cv.wait(timeout)
                timeout_times += 1

            connection = None
            if timeout_times >= 1:
                return connection

            if len(self.pending_pool) == 0:
                # Grow lazily: no idle connection and capacity remains.
                connection = self.create()
            else:
                connection = self.pending_pool.pop()
            self.active_pool.add(connection)
            scoped_connection = ScopedConnection(self, connection)
            return scoped_connection

    def release(self, connection):
        """Move a checked-out connection back to the idle set.

        Raises:
            RuntimeError: if ``connection`` was not checked out from here.
        """
        with self.cv:
            if connection not in self.active_pool:
                raise RuntimeError('\"{}\" not found in pool \"{}\"'.format(connection, self.name))
            self.active_pool.remove(connection)
            self.pending_pool.add(connection)

    def create(self):
        """Build a fresh (not yet pooled) Connection for this pool's URI."""
        connection = Connection(name=self.name, uri=self.uri, max_retry=self.max_retry, **self.kwargs)
        return connection
class ConnectionGroup(topology.TopoGroup):
    """Topology group whose members are ConnectionPool objects."""

    def __init__(self, name):
        super().__init__(name)

    def on_pre_add(self, topo_object):
        """Vet a pool before it joins the group.

        Opens one connection to check reachability and that the remote
        server version is supported; returning False vetoes the add.
        """
        conn = topo_object.fetch()
        conn.on_connect(metadata=None)
        status, version = conn.conn.server_version()
        if not status.OK():
            logger.error('Cannot connect to newly added address: {}. Remove it now'.format(topo_object.name))
            return False
        if version in settings.SERVER_VERSIONS:
            return True
        logger.error('Cannot connect to server of version: {}. Only {} supported'.format(version,
                                                                                         settings.SERVER_VERSIONS))
        return False

    def create(self, name, **kwargs):
        """Build a ConnectionPool and add it; returns (status, pool-or-None)."""
        if not kwargs.get('uri', None):
            raise RuntimeError('\"uri\" is required to create connection pool')
        pool = ConnectionPool(name=name, **kwargs)
        status = self.add(pool)
        return status, (pool if status == topology.StatusType.OK else None)
class ConnectionTopology(topology.Topology):
    """Topology specialized to hold ConnectionGroup instances."""

    def __init__(self):
        super().__init__()

    def create(self, name):
        """Add a new ConnectionGroup; returns (status, group-or-None)."""
        group = ConnectionGroup(name)
        status = self.add_group(group)
        return status, (None if status == topology.StatusType.DUPLICATED else group)
@singleton
class ConnectionMgr:
......@@ -126,6 +261,14 @@ class ConnectionMgr:
def on_new_meta(self, name, url):
logger.info('Register Connection: name={};url={}'.format(name, url))
self.metas[name] = url
conn = self.conn(name, metadata=None)
conn.on_connect(metadata=None)
status, _ = conn.conn.server_version()
if not status.OK():
logger.error('Cannot connect to newly added address: {}. Remove it now'.format(name))
self.unregister(name)
return False
return True
def on_duplicate_meta(self, name, url):
if self.metas[name] == url:
......@@ -135,19 +278,22 @@ class ConnectionMgr:
def on_same_meta(self, name, url):
# logger.warning('Register same meta: {}:{}'.format(name, url))
pass
return True
def on_diff_meta(self, name, url):
logger.warning('Received {} with diff url={}'.format(name, url))
self.metas[name] = url
self.conns[name] = {}
return True
def on_unregister_meta(self, name, url):
logger.info('Unregister name={};url={}'.format(name, url))
self.conns.pop(name, None)
return True
def on_nonexisted_meta(self, name):
logger.warning('Non-existed meta: {}'.format(name))
return False
def register(self, name, url):
meta = self.metas.get(name)
......
......@@ -2,20 +2,21 @@ from mishards import exceptions
class RouterMixin:
def __init__(self, conn_mgr):
self.conn_mgr = conn_mgr
def __init__(self, writable_topo, readonly_topo):
self.writable_topo = writable_topo
self.readonly_topo = readonly_topo
def routing(self, table_name, metadata=None, **kwargs):
raise NotImplemented()
def connection(self, metadata=None):
conn = self.conn_mgr.conn('WOSERVER', metadata=metadata)
conn = self.writable_topo.get_group('default').get('WOSERVER').fetch()
if conn:
conn.on_connect(metadata=metadata)
return conn.conn
def query_conn(self, name, metadata=None):
conn = self.conn_mgr.conn(name, metadata=metadata)
conn = self.readonly_topo.get_group(name).get(name).fetch()
if not conn:
raise exceptions.ConnectionNotFoundError(name, metadata=metadata)
conn.on_connect(metadata=metadata)
......
......@@ -12,8 +12,9 @@ logger = logging.getLogger(__name__)
class Factory(RouterMixin):
name = 'FileBasedHashRingRouter'
def __init__(self, conn_mgr, **kwargs):
super(Factory, self).__init__(conn_mgr)
def __init__(self, writable_topo, readonly_topo, **kwargs):
super(Factory, self).__init__(writable_topo=writable_topo,
readonly_topo=readonly_topo)
def routing(self, table_name, partition_tags=None, metadata=None, **kwargs):
range_array = kwargs.pop('range_array', None)
......@@ -46,7 +47,7 @@ class Factory(RouterMixin):
db.remove_session()
servers = self.conn_mgr.conn_names
servers = self.readonly_topo.group_names
logger.info('Available servers: {}'.format(servers))
ring = HashRing(servers)
......@@ -65,10 +66,13 @@ class Factory(RouterMixin):
@classmethod
def Create(cls, **kwargs):
conn_mgr = kwargs.pop('conn_mgr', None)
if not conn_mgr:
raise RuntimeError('Cannot find \'conn_mgr\' to initialize \'{}\''.format(self.name))
router = cls(conn_mgr, **kwargs)
writable_topo = kwargs.pop('writable_topo', None)
if not writable_topo:
raise RuntimeError('Cannot find \'writable_topo\' to initialize \'{}\''.format(self.name))
readonly_topo = kwargs.pop('readonly_topo', None)
if not readonly_topo:
raise RuntimeError('Cannot find \'readonly_topo\' to initialize \'{}\''.format(self.name))
router = cls(writable_topo=writable_topo, readonly_topo=readonly_topo, **kwargs)
return router
......
import logging
import sys
import grpc
import time
import socket
......@@ -23,7 +24,8 @@ class Server:
self.exit_flag = False
def init_app(self,
conn_mgr,
writable_topo,
readonly_topo,
tracer,
router,
discover,
......@@ -31,11 +33,14 @@ class Server:
max_workers=10,
**kwargs):
self.port = int(port)
self.conn_mgr = conn_mgr
self.writable_topo = writable_topo
self.readonly_topo = readonly_topo
self.tracer = tracer
self.router = router
self.discover = discover
logger.debug('Init grpc server with max_workers: {}'.format(max_workers))
self.server_impl = grpc.server(
thread_pool=futures.ThreadPoolExecutor(max_workers=max_workers),
options=[(cygrpc.ChannelArgKey.max_send_message_length, -1),
......@@ -50,8 +55,8 @@ class Server:
url = urlparse(woserver)
ip = socket.gethostbyname(url.hostname)
socket.inet_pton(socket.AF_INET, ip)
self.conn_mgr.register(
'WOSERVER', '{}://{}:{}'.format(url.scheme, ip, url.port or 80))
_, group = self.writable_topo.create('default')
group.create(name='WOSERVER', uri='{}://{}:{}'.format(url.scheme, ip, url.port or 80))
def register_pre_run_handler(self, func):
logger.info('Regiterring {} into server pre_run_handlers'.format(func))
......@@ -83,7 +88,7 @@ class Server:
def on_pre_run(self):
for handler in self.pre_run_handlers:
handler()
self.discover.start()
return self.discover.start()
def start(self, port=None):
handler_class = self.decorate_handler(ServiceHandler)
......@@ -97,7 +102,11 @@ class Server:
def run(self, port):
logger.info('Milvus server start ......')
port = port or self.port
self.on_pre_run()
ok = self.on_pre_run()
if not ok:
logger.error('Terminate server due to error found in on_pre_run')
sys.exit(1)
self.start(port)
logger.info('Listening on port {}'.format(port))
......
......@@ -12,6 +12,7 @@ else:
env.read_env()
SERVER_VERSIONS = ['0.6.0']
DEBUG = env.bool('DEBUG', False)
MAX_RETRY = env.int('MAX_RETRY', 3)
......@@ -26,6 +27,7 @@ config(LOG_LEVEL, LOG_PATH, LOG_NAME, TIMEZONE)
SERVER_PORT = env.int('SERVER_PORT', 19530)
SERVER_TEST_PORT = env.int('SERVER_TEST_PORT', 19530)
WOSERVER = env.str('WOSERVER')
MAX_WORKERS = env.int('MAX_WORKERS', 50)
class TracingConfig:
......
import logging
import pytest
import mock
import random
import threading
from milvus import Milvus
from mishards.connections import (ConnectionMgr, Connection)
from mishards.connections import (ConnectionMgr, Connection,
ConnectionPool, ConnectionTopology, ConnectionGroup)
from mishards.topology import StatusType
from mishards import exceptions
logger = logging.getLogger(__name__)
......@@ -11,6 +15,7 @@ logger = logging.getLogger(__name__)
@pytest.mark.usefixtures('app')
class TestConnection:
@pytest.mark.skip
def test_manager(self):
mgr = ConnectionMgr()
......@@ -99,3 +104,161 @@ class TestConnection:
this_connect = c.connect(func=None, exception_handler=error_handler)
this_connect()
assert len(errors) == 1
def test_topology(self):
    # Stub the version/reachability pre-add check so no live server is
    # needed; (True,) is truthy, which is all the hook's caller checks.
    ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,))
    w_topo = ConnectionTopology()

    # A fresh group name is created with OK status.
    status, wg1 = w_topo.create(name='wg1')
    assert w_topo.has_group(wg1)
    assert status == StatusType.OK

    # Re-creating the same name reports DUPLICATED and yields no group.
    status, wg1_dup = w_topo.create(name='wg1')
    assert wg1_dup is None
    assert status == StatusType.DUPLICATED

    # Lookup returns the very same group object.
    fetched_group = w_topo.get_group('wg1')
    assert id(fetched_group) == id(wg1)

    # A pool cannot be created without a uri.
    with pytest.raises(RuntimeError):
        wg1.create(name='wg1_p1')

    # First pool joins the group.
    status, wg1_p1 = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
    assert status == StatusType.OK
    assert wg1_p1 is not None
    assert len(wg1) == 1

    # A duplicate pool name is rejected without growing the group.
    status, wg1_p1_dup = wg1.create(name='wg1_p1', uri='127.0.0.1:19530')
    assert status == StatusType.DUPLICATED
    assert wg1_p1_dup is None
    assert len(wg1) == 1

    status, wg1_p2 = wg1.create('wg1_p2', uri='127.0.0.1:19530')
    assert status == StatusType.OK
    assert wg1_p2 is not None
    assert len(wg1) == 2

    # Removing an unknown name is a harmless no-op returning None.
    poped = wg1.remove('wg1_p3')
    assert poped is None
    assert len(wg1) == 2
    poped = wg1.remove('wg1_p2')
    assert poped.name == 'wg1_p2'
    assert len(wg1) == 1

    fetched_p1 = wg1.get(wg1_p1.name)
    assert fetched_p1 == wg1_p1
    fetched_p1 = w_topo.get_group('wg1').get('wg1_p1')

    # fetch() grows the pool lazily; release() returns a connection to
    # the pending set without shrinking the total.
    conn1 = fetched_p1.fetch()
    assert len(fetched_p1) == 1
    assert fetched_p1.active_num == 1
    conn2 = fetched_p1.fetch()
    assert len(fetched_p1) == 2
    assert fetched_p1.active_num == 2
    conn2.release()
    assert len(fetched_p1) == 2
    assert fetched_p1.active_num == 1

    assert len(w_topo.group_names) == 1
def test_connection_pool(self):
    # Stub the pre-add check; (True,) is truthy so every add passes.
    ConnectionGroup.on_pre_add = mock.MagicMock(return_value=(True,))

    def choaz_mp_fetch(capacity, count, tnum):
        # Hammer one group from several threads with a random mix of
        # create/fetch and remove operations to shake out races.
        threads_num = 5
        topo = ConnectionTopology()
        _, tg = topo.create('tg')
        pool_size = 20
        pool_names = ['p{}:19530'.format(i) for i in range(pool_size)]
        threads = []

        def Worker(group, cnt, capacity):
            ori_cnt = cnt
            assert cnt < 100
            while cnt >= 0:
                name = pool_names[random.randint(0, pool_size-1)]
                cnt -= 1
                # Roughly 1 in 4 iterations removes the pool instead of
                # using it.
                remove = (random.randint(1,4)%4 == 0)
                if remove:
                    pool = group.get(name=name)
                    # if name.startswith("p1:"):
                    #     logger.error('{} CNT={} [Remove] Group \"{}\" has pool of SIZE={} ACTIVE={}'.format(threading.get_ident(), ori_cnt-cnt, name, len(pool), pool.active_num))
                    group.remove(name)
                else:
                    # create() reports DUPLICATED when the pool already
                    # exists, so this is safe to call repeatedly.
                    group.create(name=name, uri=name, capacity=capacity)
                    pool = group.get(name=name)
                    assert pool is not None
                    conn = pool.fetch(timeout=0.01)
                    # if name.startswith("p1:"):
                    #     logger.error('{} CNT={} [Adding] Group \"{}\" has pool of SIZE={} ACTIVE={}'.format(threading.get_ident(), ori_cnt-cnt, name, len(pool), pool.active_num))

        for _ in range(threads_num):
            t = threading.Thread(target=Worker, args=(tg, count, tnum))
            threads.append(t)
            t.start()
        for t in threads:
            t.join()

    choaz_mp_fetch(4, 40, 8)

    def check_mp_fetch(capacity=-1):
        # With a bounded pool only `capacity` of the 2*capacity threads
        # can obtain a connection; an unbounded pool serves everyone.
        w2 = ConnectionPool(name='w2', uri='127.0.0.1:19530', max_retry=2, capacity=capacity)
        connections = []

        def GetConnection(pool):
            conn = pool.fetch(timeout=0.1)
            if conn:
                connections.append(conn)

        threads = []
        threads_num = 10 if capacity < 0 else 2*capacity
        for _ in range(threads_num):
            t = threading.Thread(target=GetConnection, args=(w2,))
            threads.append(t)
            t.start()
        for t in threads:
            t.join()
        expected_size = threads_num if capacity < 0 else capacity
        assert len(connections) == expected_size

    check_mp_fetch(5)
    check_mp_fetch()

    # Bounded pool of 2: the third concurrent fetch times out with None.
    w1 = ConnectionPool(name='w1', uri='127.0.0.1:19530', max_retry=2, capacity=2)
    w1_1 = w1.fetch()
    assert len(w1) == 1
    assert w1.active_num == 1
    w1_2 = w1.fetch()
    assert len(w1) == 2
    assert w1.active_num == 2
    w1_3 = w1.fetch()
    assert w1_3 is None
    assert len(w1) == 2
    assert w1.active_num == 2

    w1_1.release()
    assert len(w1) == 2
    assert w1.active_num == 1

    def check(pool, expected_size, expected_active_num):
        # The ScopedConnection fetched here is handed back when `w` goes
        # out of scope (ScopedConnection.__del__ -> release).
        w = pool.fetch()
        assert len(pool) == expected_size
        assert pool.active_num == expected_active_num

    check(w1, 2, 2)
    assert len(w1) == 2
    assert w1.active_num == 1

    # A connection created outside the pool cannot be released into it.
    wild_w = w1.create()
    with pytest.raises(RuntimeError):
        w1.release(wild_w)

    # Attribute access proxies through to the underlying connection.
    ret = w1_2.can_retry
    assert ret == w1_2.connection.can_retry
......@@ -14,6 +14,8 @@ from mishards.service_handler import ServiceHandler
from mishards.grpc_utils.grpc_args_parser import GrpcArgsParser as Parser
from mishards.factories import TableFilesFactory, TablesFactory, TableFiles, Tables
from mishards.router import RouterMixin
from mishards.connections import (ConnectionMgr, Connection,
ConnectionPool, ConnectionTopology, ConnectionGroup)
logger = logging.getLogger(__name__)
......@@ -23,15 +25,13 @@ BAD = Status(code=Status.PERMISSION_DENIED, message='Fail')
@pytest.mark.usefixtures('started_app')
class TestServer:
@property
def client(self):
m = Milvus()
m.connect(host='localhost', port=settings.SERVER_TEST_PORT)
return m
def test_server_start(self, started_app):
assert started_app.conn_mgr.metas.get('WOSERVER') == settings.WOSERVER
def test_cmd(self, started_app):
ServiceHandler._get_server_version = mock.MagicMock(return_value=(OK,
''))
......@@ -228,6 +228,7 @@ class TestServer:
def random_data(self, n, dimension):
return [[random.random() for _ in range(dimension)] for _ in range(n)]
@pytest.mark.skip
def test_search(self, started_app):
table_name = inspect.currentframe().f_code.co_name
to_index_cnt = random.randint(10, 20)
......
import logging
import threading
import enum

logger = logging.getLogger(__name__)


class TopoObject:
    """Named leaf of a topology; identity is fully determined by ``name``."""

    def __init__(self, name, **kwargs):
        self.name = name
        self.kwargs = kwargs

    def __eq__(self, other):
        # Allow comparison against a bare string name as well as peers.
        if isinstance(other, str):
            return self.name == other
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return '<TopoObject: {}>'.format(self.name)


class StatusType(enum.Enum):
    """Result codes for topology add/create operations."""
    OK = 1
    DUPLICATED = 2
    ADD_ERROR = 3
    VERSION_ERROR = 4


class TopoGroup:
    """Thread-safe, named collection of TopoObject items keyed by name."""

    def __init__(self, name):
        self.name = name
        self.items = {}
        self.cv = threading.Condition()

    def on_duplicate(self, topo_object):
        # NOTE(review): defined but never invoked by _add_no_lock — confirm intent.
        logger.warning('Duplicated topo_object \"{}\" into group \"{}\"'.format(topo_object, self.name))

    def on_added(self, topo_object):
        """Post-add hook; returning False rolls the add back (ADD_ERROR)."""
        return True

    def on_pre_add(self, topo_object):
        """Pre-add hook; returning False vetoes the add (VERSION_ERROR)."""
        return True

    def _add_no_lock(self, topo_object):
        # Caller must hold self.cv.
        if topo_object.name in self.items:
            return StatusType.DUPLICATED
        logger.info('Adding topo_object \"{}\" into group \"{}\"'.format(topo_object, self.name))
        if not self.on_pre_add(topo_object):
            return StatusType.VERSION_ERROR
        self.items[topo_object.name] = topo_object
        ok = self.on_added(topo_object)
        if not ok:
            # Roll back when the post-add hook rejects the object.
            self._remove_no_lock(topo_object.name)
        return StatusType.OK if ok else StatusType.ADD_ERROR

    def add(self, topo_object):
        """Add ``topo_object``; returns a StatusType."""
        with self.cv:
            return self._add_no_lock(topo_object)

    def __len__(self):
        return len(self.items)

    def __str__(self):
        return '<TopoGroup: {}>'.format(self.name)

    def get(self, name):
        """Return the member called ``name``, or None."""
        return self.items.get(name, None)

    def _remove_no_lock(self, name):
        logger.info('Removing topo_object \"{}\" from group \"{}\"'.format(name, self.name))
        return self.items.pop(name, None)

    def remove(self, name):
        """Remove and return the member called ``name`` (None if absent)."""
        with self.cv:
            return self._remove_no_lock(name)


class Topology:
    """Thread-safe registry of TopoGroup objects keyed by group name."""

    def __init__(self):
        self.topo_groups = {}
        self.cv = threading.Condition()

    def on_duplicated_group(self, group):
        logger.warning('Duplicated group \"{}\" found!'.format(group))
        return StatusType.DUPLICATED

    def on_pre_add_group(self, group):
        logger.debug('Pre add group \"{}\"'.format(group))
        return StatusType.OK

    def on_post_add_group(self, group):
        logger.debug('Post add group \"{}\"'.format(group))
        return StatusType.OK

    def get_group(self, name):
        """Return the group called ``name``, or None."""
        return self.topo_groups.get(name, None)

    def has_group(self, group):
        """Accepts either a group object or a bare name string."""
        key = group if isinstance(group, str) else group.name
        return key in self.topo_groups

    def _add_group_no_lock(self, group):
        logger.info('Adding group \"{}\"'.format(group))
        self.topo_groups[group.name] = group

    def add_group(self, group):
        """Register ``group``; returns StatusType.OK or DUPLICATED.

        NOTE(review): the duplicate check runs outside the lock, so two
        racing adds of the same name may both pass — confirm acceptable.
        """
        self.on_pre_add_group(group)
        if self.has_group(group):
            return self.on_duplicated_group(group)
        with self.cv:
            self._add_group_no_lock(group)
        return self.on_post_add_group(group)

    def on_delete_not_existed_group(self, group):
        logger.warning('Deleting non-existed group \"{}\"'.format(group))

    def on_pre_delete_group(self, group):
        logger.debug('Pre delete group \"{}\"'.format(group))

    def on_post_delete_group(self, group):
        logger.debug('Post delete group \"{}\"'.format(group))

    def _delete_group_no_lock(self, group):
        logger.info('Deleting group \"{}\"'.format(group))
        delete_key = group if isinstance(group, str) else group.name
        return self.topo_groups.pop(delete_key, None)

    def delete_group(self, group):
        """Remove ``group`` (object or name) and return the removed group.

        Returns None when no such group existed. Fixes of the original:
        it called the non-existent ``_delete_group_lock`` (AttributeError
        on every invocation; the defined method is _delete_group_no_lock);
        it tested ``not deleted_group``, which misfires because an empty
        TopoGroup is falsy via __len__; and it returned None instead of
        the removed group, which callers assign
        (``pool = topo.delete_group(name)``).
        """
        self.on_pre_delete_group(group)
        with self.cv:
            deleted_group = self._delete_group_no_lock(group)
        if deleted_group is None:
            return self.on_delete_not_existed_group(group)
        self.on_post_delete_group(deleted_group)
        return deleted_group

    @property
    def group_names(self):
        """View of all registered group names."""
        return self.topo_groups.keys()
......@@ -26,6 +26,8 @@ class JaegerFactory:
tracer,
log_payloads=plugin_config.TRACING_LOG_PAYLOAD,
span_decorator=span_decorator)
jaeger_logger = logging.getLogger('jaeger_tracing')
jaeger_logger.setLevel(logging.ERROR)
return Tracer(tracer, tracer_interceptor, intercept_server)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册