Convert test_no_auth into pytest style
This commit is contained in:
parent
c379fb4beb
commit
e29666cea9
|
@ -0,0 +1,139 @@
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
import gc
|
||||||
|
import pytest
|
||||||
|
import socket
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def unused_port():
|
||||||
|
def f():
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(('127.0.0.1', 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.yield_fixture
|
||||||
|
def loop(request):
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
yield loop
|
||||||
|
|
||||||
|
loop.stop()
|
||||||
|
loop.run_forever()
|
||||||
|
loop.close()
|
||||||
|
gc.collect()
|
||||||
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.yield_fixture
|
||||||
|
def create_server(loop, unused_port):
|
||||||
|
app = handler = srv = None
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def create(*, debug=False, ssl_ctx=None, proto='http'):
|
||||||
|
nonlocal app, handler, srv
|
||||||
|
app = web.Application(loop=loop)
|
||||||
|
port = unused_port()
|
||||||
|
handler = app.make_handler(debug=debug, keep_alive_on=False)
|
||||||
|
srv = yield from loop.create_server(handler, '127.0.0.1', port,
|
||||||
|
ssl=ssl_ctx)
|
||||||
|
if ssl_ctx:
|
||||||
|
proto += 's'
|
||||||
|
url = "{}://127.0.0.1:{}".format(proto, port)
|
||||||
|
return app, url
|
||||||
|
|
||||||
|
yield create
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def finish():
|
||||||
|
yield from handler.finish_connections()
|
||||||
|
yield from app.finish()
|
||||||
|
srv.close()
|
||||||
|
yield from srv.wait_closed()
|
||||||
|
|
||||||
|
loop.run_until_complete(finish())
|
||||||
|
|
||||||
|
|
||||||
|
class Client:
|
||||||
|
def __init__(self, session, url):
|
||||||
|
self._session = session
|
||||||
|
if not url.endswith('/'):
|
||||||
|
url += '/'
|
||||||
|
self._url = url
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self._session.close()
|
||||||
|
|
||||||
|
def get(self, path, **kwargs):
|
||||||
|
while path.startswith('/'):
|
||||||
|
path = path[1:]
|
||||||
|
url = self._url + path
|
||||||
|
return self._session.get(url, **kwargs)
|
||||||
|
|
||||||
|
def post(self, path, **kwargs):
|
||||||
|
while path.startswith('/'):
|
||||||
|
path = path[1:]
|
||||||
|
url = self._url + path
|
||||||
|
return self._session.post(url, **kwargs)
|
||||||
|
|
||||||
|
def ws_connect(self, path, **kwargs):
|
||||||
|
while path.startswith('/'):
|
||||||
|
path = path[1:]
|
||||||
|
url = self._url + path
|
||||||
|
return self._session.ws_connect(url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.yield_fixture
|
||||||
|
def create_app_and_client(create_server, loop):
|
||||||
|
client = None
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
|
def maker(*, server_params=None, client_params=None):
|
||||||
|
nonlocal client
|
||||||
|
if server_params is None:
|
||||||
|
server_params = {}
|
||||||
|
server_params.setdefault('debug', False)
|
||||||
|
server_params.setdefault('ssl_ctx', None)
|
||||||
|
app, url = yield from create_server(**server_params)
|
||||||
|
if client_params is None:
|
||||||
|
client_params = {}
|
||||||
|
client = Client(aiohttp.ClientSession(loop=loop, **client_params), url)
|
||||||
|
return app, client
|
||||||
|
|
||||||
|
yield maker
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.tryfirst
|
||||||
|
def pytest_pycollect_makeitem(collector, name, obj):
|
||||||
|
if collector.funcnamefilter(name):
|
||||||
|
if not callable(obj):
|
||||||
|
return
|
||||||
|
item = pytest.Function(name, parent=collector)
|
||||||
|
if 'run_loop' in item.keywords:
|
||||||
|
return list(collector._genfunctions(name, obj))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.tryfirst
|
||||||
|
def pytest_pyfunc_call(pyfuncitem):
|
||||||
|
"""
|
||||||
|
Run asyncio marked test functions in an event loop instead of a normal
|
||||||
|
function call.
|
||||||
|
"""
|
||||||
|
if 'run_loop' in pyfuncitem.keywords:
|
||||||
|
funcargs = pyfuncitem.funcargs
|
||||||
|
loop = funcargs['loop']
|
||||||
|
testargs = {arg: funcargs[arg]
|
||||||
|
for arg in pyfuncitem._fixtureinfo.argnames}
|
||||||
|
loop.run_until_complete(pyfuncitem.obj(**testargs))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_runtest_setup(item):
|
||||||
|
if 'run_loop' in item.keywords and 'loop' not in item.fixturenames:
|
||||||
|
# inject an event loop fixture for all async tests
|
||||||
|
item.fixturenames.append('loop')
|
|
@ -1,82 +1,41 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import socket
|
import pytest
|
||||||
import unittest
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp_security import (authorized_userid, permits)
|
from aiohttp_security import authorized_userid, permits
|
||||||
|
|
||||||
|
|
||||||
class TestNoAuth(unittest.TestCase):
|
@pytest.mark.run_loop
|
||||||
|
def test_authorized_userid(create_app_and_client):
|
||||||
def setUp(self):
|
|
||||||
self.loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(None)
|
|
||||||
self.client = aiohttp.ClientSession(loop=self.loop)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
self.client.close()
|
|
||||||
self.loop.run_until_complete(self.handler.finish_connections())
|
|
||||||
self.srv.close()
|
|
||||||
self.loop.run_until_complete(self.srv.wait_closed())
|
|
||||||
self.loop.close()
|
|
||||||
|
|
||||||
def find_unused_port(self):
|
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
s.bind(('127.0.0.1', 0))
|
|
||||||
port = s.getsockname()[1]
|
|
||||||
s.close()
|
|
||||||
return port
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
|
||||||
def create_server(self):
|
|
||||||
app = web.Application(loop=self.loop)
|
|
||||||
|
|
||||||
port = self.find_unused_port()
|
|
||||||
self.handler = app.make_handler(
|
|
||||||
debug=False, keep_alive_on=False)
|
|
||||||
srv = yield from self.loop.create_server(
|
|
||||||
self.handler, '127.0.0.1', port)
|
|
||||||
url = "http://127.0.0.1:{}/".format(port)
|
|
||||||
self.srv = srv
|
|
||||||
return app, srv, url
|
|
||||||
|
|
||||||
def test_authorized_userid(self):
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def check(request):
|
def check(request):
|
||||||
userid = yield from authorized_userid(request)
|
userid = yield from authorized_userid(request)
|
||||||
self.assertIsNone(userid)
|
assert userid is None
|
||||||
return web.Response()
|
return web.Response()
|
||||||
|
|
||||||
@asyncio.coroutine
|
app, client = yield from create_app_and_client()
|
||||||
def go():
|
|
||||||
app, srv, url = yield from self.create_server()
|
|
||||||
app.router.add_route('GET', '/', check)
|
app.router.add_route('GET', '/', check)
|
||||||
resp = yield from self.client.get(url)
|
resp = yield from client.get('/')
|
||||||
self.assertEqual(200, resp.status)
|
assert 200 == resp.status
|
||||||
yield from resp.release()
|
yield from resp.release()
|
||||||
|
|
||||||
self.loop.run_until_complete(go())
|
|
||||||
|
|
||||||
def test_permits(self):
|
@pytest.mark.run_loop
|
||||||
|
def test_permits(create_app_and_client):
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def check(request):
|
def check(request):
|
||||||
ret = yield from permits(request, 'read')
|
ret = yield from permits(request, 'read')
|
||||||
self.assertTrue(ret)
|
assert ret
|
||||||
ret = yield from permits(request, 'write')
|
ret = yield from permits(request, 'write')
|
||||||
self.assertTrue(ret)
|
assert ret
|
||||||
ret = yield from permits(request, 'unknown')
|
ret = yield from permits(request, 'unknown')
|
||||||
self.assertTrue(ret)
|
assert ret
|
||||||
return web.Response()
|
return web.Response()
|
||||||
|
|
||||||
@asyncio.coroutine
|
app, client = yield from create_app_and_client()
|
||||||
def go():
|
|
||||||
app, srv, url = yield from self.create_server()
|
|
||||||
app.router.add_route('GET', '/', check)
|
app.router.add_route('GET', '/', check)
|
||||||
resp = yield from self.client.get(url)
|
resp = yield from client.get('/')
|
||||||
self.assertEqual(200, resp.status)
|
assert 200 == resp.status
|
||||||
yield from resp.release()
|
yield from resp.release()
|
||||||
|
|
||||||
self.loop.run_until_complete(go())
|
|
||||||
|
|
Loading…
Reference in New Issue