From e29666cea97e9087e32666a56f5957aa101839a0 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 19 Nov 2015 13:53:38 +0200 Subject: [PATCH] Convert test_no_auth into pytest style --- tests/conftest.py | 139 ++++++++++++++++++++++++++++++++++++++++++ tests/test_no_auth.py | 99 +++++++++--------------------- 2 files changed, 168 insertions(+), 70 deletions(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..180a477 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,139 @@ +import aiohttp +import asyncio +import gc +import pytest +import socket +from aiohttp import web + + +@pytest.fixture +def unused_port(): + def f(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] + return f + + +@pytest.yield_fixture +def loop(request): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) + + yield loop + + loop.stop() + loop.run_forever() + loop.close() + gc.collect() + asyncio.set_event_loop(None) + + +@pytest.yield_fixture +def create_server(loop, unused_port): + app = handler = srv = None + + @asyncio.coroutine + def create(*, debug=False, ssl_ctx=None, proto='http'): + nonlocal app, handler, srv + app = web.Application(loop=loop) + port = unused_port() + handler = app.make_handler(debug=debug, keep_alive_on=False) + srv = yield from loop.create_server(handler, '127.0.0.1', port, + ssl=ssl_ctx) + if ssl_ctx: + proto += 's' + url = "{}://127.0.0.1:{}".format(proto, port) + return app, url + + yield create + + @asyncio.coroutine + def finish(): + yield from handler.finish_connections() + yield from app.finish() + srv.close() + yield from srv.wait_closed() + + loop.run_until_complete(finish()) + + +class Client: + def __init__(self, session, url): + self._session = session + if not url.endswith('/'): + url += '/' + self._url = url + + def close(self): + self._session.close() + + def get(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.get(url, **kwargs) + + def post(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.post(url, **kwargs) + + def ws_connect(self, path, **kwargs): + while path.startswith('/'): + path = path[1:] + url = self._url + path + return self._session.ws_connect(url, **kwargs) + + +@pytest.yield_fixture +def create_app_and_client(create_server, loop): + client = None + + @asyncio.coroutine + def maker(*, server_params=None, client_params=None): + nonlocal client + if server_params is None: + server_params = {} + server_params.setdefault('debug', False) + server_params.setdefault('ssl_ctx', None) + app, url = yield from create_server(**server_params) + if client_params is None: + client_params = {} + client = Client(aiohttp.ClientSession(loop=loop, **client_params), url) + return app, client + + yield maker + client.close() + + +@pytest.mark.tryfirst +def pytest_pycollect_makeitem(collector, name, obj): + if collector.funcnamefilter(name): + if not callable(obj): + return + item = pytest.Function(name, parent=collector) + if 'run_loop' in item.keywords: + return list(collector._genfunctions(name, obj)) + + +@pytest.mark.tryfirst +def pytest_pyfunc_call(pyfuncitem): + """ + Run asyncio marked test functions in an event loop instead of a normal + function call. + """ + if 'run_loop' in pyfuncitem.keywords: + funcargs = pyfuncitem.funcargs + loop = funcargs['loop'] + testargs = {arg: funcargs[arg] + for arg in pyfuncitem._fixtureinfo.argnames} + loop.run_until_complete(pyfuncitem.obj(**testargs)) + return True + + +def pytest_runtest_setup(item): + if 'run_loop' in item.keywords and 'loop' not in item.fixturenames: + # inject an event loop fixture for all async tests + item.fixturenames.append('loop') diff --git a/tests/test_no_auth.py b/tests/test_no_auth.py index ca3c04a..8c3d99e 100644 --- a/tests/test_no_auth.py +++ b/tests/test_no_auth.py @@ -1,82 +1,41 @@ import asyncio -import socket -import unittest +import pytest -import aiohttp from aiohttp import web -from aiohttp_security import (authorized_userid, permits) +from aiohttp_security import authorized_userid, permits -class TestNoAuth(unittest.TestCase): - - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - self.client = aiohttp.ClientSession(loop=self.loop) - - def tearDown(self): - self.client.close() - self.loop.run_until_complete(self.handler.finish_connections()) - self.srv.close() - self.loop.run_until_complete(self.srv.wait_closed()) - self.loop.close() - - def find_unused_port(self): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.bind(('127.0.0.1', 0)) - port = s.getsockname()[1] - s.close() - return port +@pytest.mark.run_loop +def test_authorized_userid(create_app_and_client): @asyncio.coroutine - def create_server(self): - app = web.Application(loop=self.loop) + def check(request): + userid = yield from authorized_userid(request) + assert userid is None + return web.Response() - port = self.find_unused_port() - self.handler = app.make_handler( - debug=False, keep_alive_on=False) - srv = yield from self.loop.create_server( - self.handler, '127.0.0.1', port) - url = "http://127.0.0.1:{}/".format(port) - self.srv = srv - return app, srv, url + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', check) + resp = yield from client.get('/') + assert 200 == resp.status + yield from resp.release() - def test_authorized_userid(self): - @asyncio.coroutine - def check(request): - userid = yield from authorized_userid(request) - self.assertIsNone(userid) - return web.Response() +@pytest.mark.run_loop +def test_permits(create_app_and_client): - @asyncio.coroutine - def go(): - app, srv, url = yield from self.create_server() - app.router.add_route('GET', '/', check) - resp = yield from self.client.get(url) - self.assertEqual(200, resp.status) - yield from resp.release() + @asyncio.coroutine + def check(request): + ret = yield from permits(request, 'read') + assert ret + ret = yield from permits(request, 'write') + assert ret + ret = yield from permits(request, 'unknown') + assert ret + return web.Response() - self.loop.run_until_complete(go()) - - def test_permits(self): - - @asyncio.coroutine - def check(request): - ret = yield from permits(request, 'read') - self.assertTrue(ret) - ret = yield from permits(request, 'write') - self.assertTrue(ret) - ret = yield from permits(request, 'unknown') - self.assertTrue(ret) - return web.Response() - - @asyncio.coroutine - def go(): - app, srv, url = yield from self.create_server() - app.router.add_route('GET', '/', check) - resp = yield from self.client.get(url) - self.assertEqual(200, resp.status) - yield from resp.release() - - self.loop.run_until_complete(go()) + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', check) + resp = yield from client.get('/') + assert 200 == resp.status + yield from resp.release()