Convert test_no_auth into pytest style
This commit is contained in:
parent
c379fb4beb
commit
e29666cea9
|
@ -0,0 +1,139 @@
|
|||
import aiohttp
|
||||
import asyncio
|
||||
import gc
|
||||
import pytest
|
||||
import socket
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unused_port():
|
||||
def f():
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
return s.getsockname()[1]
|
||||
return f
|
||||
|
||||
|
||||
@pytest.yield_fixture
|
||||
def loop(request):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
yield loop
|
||||
|
||||
loop.stop()
|
||||
loop.run_forever()
|
||||
loop.close()
|
||||
gc.collect()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
|
||||
@pytest.yield_fixture
|
||||
def create_server(loop, unused_port):
|
||||
app = handler = srv = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def create(*, debug=False, ssl_ctx=None, proto='http'):
|
||||
nonlocal app, handler, srv
|
||||
app = web.Application(loop=loop)
|
||||
port = unused_port()
|
||||
handler = app.make_handler(debug=debug, keep_alive_on=False)
|
||||
srv = yield from loop.create_server(handler, '127.0.0.1', port,
|
||||
ssl=ssl_ctx)
|
||||
if ssl_ctx:
|
||||
proto += 's'
|
||||
url = "{}://127.0.0.1:{}".format(proto, port)
|
||||
return app, url
|
||||
|
||||
yield create
|
||||
|
||||
@asyncio.coroutine
|
||||
def finish():
|
||||
yield from handler.finish_connections()
|
||||
yield from app.finish()
|
||||
srv.close()
|
||||
yield from srv.wait_closed()
|
||||
|
||||
loop.run_until_complete(finish())
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, session, url):
|
||||
self._session = session
|
||||
if not url.endswith('/'):
|
||||
url += '/'
|
||||
self._url = url
|
||||
|
||||
def close(self):
|
||||
self._session.close()
|
||||
|
||||
def get(self, path, **kwargs):
|
||||
while path.startswith('/'):
|
||||
path = path[1:]
|
||||
url = self._url + path
|
||||
return self._session.get(url, **kwargs)
|
||||
|
||||
def post(self, path, **kwargs):
|
||||
while path.startswith('/'):
|
||||
path = path[1:]
|
||||
url = self._url + path
|
||||
return self._session.post(url, **kwargs)
|
||||
|
||||
def ws_connect(self, path, **kwargs):
|
||||
while path.startswith('/'):
|
||||
path = path[1:]
|
||||
url = self._url + path
|
||||
return self._session.ws_connect(url, **kwargs)
|
||||
|
||||
|
||||
@pytest.yield_fixture
|
||||
def create_app_and_client(create_server, loop):
|
||||
client = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def maker(*, server_params=None, client_params=None):
|
||||
nonlocal client
|
||||
if server_params is None:
|
||||
server_params = {}
|
||||
server_params.setdefault('debug', False)
|
||||
server_params.setdefault('ssl_ctx', None)
|
||||
app, url = yield from create_server(**server_params)
|
||||
if client_params is None:
|
||||
client_params = {}
|
||||
client = Client(aiohttp.ClientSession(loop=loop, **client_params), url)
|
||||
return app, client
|
||||
|
||||
yield maker
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.tryfirst
|
||||
def pytest_pycollect_makeitem(collector, name, obj):
|
||||
if collector.funcnamefilter(name):
|
||||
if not callable(obj):
|
||||
return
|
||||
item = pytest.Function(name, parent=collector)
|
||||
if 'run_loop' in item.keywords:
|
||||
return list(collector._genfunctions(name, obj))
|
||||
|
||||
|
||||
@pytest.mark.tryfirst
|
||||
def pytest_pyfunc_call(pyfuncitem):
|
||||
"""
|
||||
Run asyncio marked test functions in an event loop instead of a normal
|
||||
function call.
|
||||
"""
|
||||
if 'run_loop' in pyfuncitem.keywords:
|
||||
funcargs = pyfuncitem.funcargs
|
||||
loop = funcargs['loop']
|
||||
testargs = {arg: funcargs[arg]
|
||||
for arg in pyfuncitem._fixtureinfo.argnames}
|
||||
loop.run_until_complete(pyfuncitem.obj(**testargs))
|
||||
return True
|
||||
|
||||
|
||||
def pytest_runtest_setup(item):
|
||||
if 'run_loop' in item.keywords and 'loop' not in item.fixturenames:
|
||||
# inject an event loop fixture for all async tests
|
||||
item.fixturenames.append('loop')
|
|
@ -1,82 +1,41 @@
|
|||
import asyncio
|
||||
import socket
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from aiohttp_security import (authorized_userid, permits)
|
||||
from aiohttp_security import authorized_userid, permits
|
||||
|
||||
|
||||
class TestNoAuth(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(None)
|
||||
self.client = aiohttp.ClientSession(loop=self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.client.close()
|
||||
self.loop.run_until_complete(self.handler.finish_connections())
|
||||
self.srv.close()
|
||||
self.loop.run_until_complete(self.srv.wait_closed())
|
||||
self.loop.close()
|
||||
|
||||
def find_unused_port(self):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.bind(('127.0.0.1', 0))
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
|
||||
@asyncio.coroutine
|
||||
def create_server(self):
|
||||
app = web.Application(loop=self.loop)
|
||||
|
||||
port = self.find_unused_port()
|
||||
self.handler = app.make_handler(
|
||||
debug=False, keep_alive_on=False)
|
||||
srv = yield from self.loop.create_server(
|
||||
self.handler, '127.0.0.1', port)
|
||||
url = "http://127.0.0.1:{}/".format(port)
|
||||
self.srv = srv
|
||||
return app, srv, url
|
||||
|
||||
def test_authorized_userid(self):
|
||||
@pytest.mark.run_loop
|
||||
def test_authorized_userid(create_app_and_client):
|
||||
|
||||
@asyncio.coroutine
|
||||
def check(request):
|
||||
userid = yield from authorized_userid(request)
|
||||
self.assertIsNone(userid)
|
||||
assert userid is None
|
||||
return web.Response()
|
||||
|
||||
@asyncio.coroutine
|
||||
def go():
|
||||
app, srv, url = yield from self.create_server()
|
||||
app, client = yield from create_app_and_client()
|
||||
app.router.add_route('GET', '/', check)
|
||||
resp = yield from self.client.get(url)
|
||||
self.assertEqual(200, resp.status)
|
||||
resp = yield from client.get('/')
|
||||
assert 200 == resp.status
|
||||
yield from resp.release()
|
||||
|
||||
self.loop.run_until_complete(go())
|
||||
|
||||
def test_permits(self):
|
||||
@pytest.mark.run_loop
|
||||
def test_permits(create_app_and_client):
|
||||
|
||||
@asyncio.coroutine
|
||||
def check(request):
|
||||
ret = yield from permits(request, 'read')
|
||||
self.assertTrue(ret)
|
||||
assert ret
|
||||
ret = yield from permits(request, 'write')
|
||||
self.assertTrue(ret)
|
||||
assert ret
|
||||
ret = yield from permits(request, 'unknown')
|
||||
self.assertTrue(ret)
|
||||
assert ret
|
||||
return web.Response()
|
||||
|
||||
@asyncio.coroutine
|
||||
def go():
|
||||
app, srv, url = yield from self.create_server()
|
||||
app, client = yield from create_app_and_client()
|
||||
app.router.add_route('GET', '/', check)
|
||||
resp = yield from self.client.get(url)
|
||||
self.assertEqual(200, resp.status)
|
||||
resp = yield from client.get('/')
|
||||
assert 200 == resp.status
|
||||
yield from resp.release()
|
||||
|
||||
self.loop.run_until_complete(go())
|
||||
|
|
Loading…
Reference in New Issue