From e29666cea97e9087e32666a56f5957aa101839a0 Mon Sep 17 00:00:00 2001
From: Andrew Svetlov <andrew.svetlov@gmail.com>
Date: Thu, 19 Nov 2015 13:53:38 +0200
Subject: [PATCH] Convert test_no_auth into pytest style

---
 tests/conftest.py     | 139 ++++++++++++++++++++++++++++++++++++++++++
 tests/test_no_auth.py |  99 +++++++++---------------------
 2 files changed, 168 insertions(+), 70 deletions(-)
 create mode 100644 tests/conftest.py

diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..180a477
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,139 @@
+import aiohttp
+import asyncio
+import gc
+import pytest
+import socket
+from aiohttp import web
+
+
+@pytest.fixture
+def unused_port():
+    def f():
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(('127.0.0.1', 0))
+            return s.getsockname()[1]
+    return f
+
+
+@pytest.yield_fixture
+def loop(request):
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(None)
+
+    yield loop
+
+    loop.stop()
+    loop.run_forever()
+    loop.close()
+    gc.collect()
+    asyncio.set_event_loop(None)
+
+
+@pytest.yield_fixture
+def create_server(loop, unused_port):
+    app = handler = srv = None
+
+    @asyncio.coroutine
+    def create(*, debug=False, ssl_ctx=None, proto='http'):
+        nonlocal app, handler, srv
+        app = web.Application(loop=loop)
+        port = unused_port()
+        handler = app.make_handler(debug=debug, keep_alive_on=False)
+        srv = yield from loop.create_server(handler, '127.0.0.1', port,
+                                            ssl=ssl_ctx)
+        if ssl_ctx:
+            proto += 's'
+        url = "{}://127.0.0.1:{}".format(proto, port)
+        return app, url
+
+    yield create
+
+    @asyncio.coroutine
+    def finish():
+        yield from handler.finish_connections()
+        yield from app.finish()
+        srv.close()
+        yield from srv.wait_closed()
+
+    loop.run_until_complete(finish())
+
+
+class Client:
+    def __init__(self, session, url):
+        self._session = session
+        if not url.endswith('/'):
+            url += '/'
+        self._url = url
+
+    def close(self):
+        self._session.close()
+
+    def get(self, path, **kwargs):
+        while path.startswith('/'):
+            path = path[1:]
+        url = self._url + path
+        return self._session.get(url, **kwargs)
+
+    def post(self, path, **kwargs):
+        while path.startswith('/'):
+            path = path[1:]
+        url = self._url + path
+        return self._session.post(url, **kwargs)
+
+    def ws_connect(self, path, **kwargs):
+        while path.startswith('/'):
+            path = path[1:]
+        url = self._url + path
+        return self._session.ws_connect(url, **kwargs)
+
+
+@pytest.yield_fixture
+def create_app_and_client(create_server, loop):
+    client = None
+
+    @asyncio.coroutine
+    def maker(*, server_params=None, client_params=None):
+        nonlocal client
+        if server_params is None:
+            server_params = {}
+        server_params.setdefault('debug', False)
+        server_params.setdefault('ssl_ctx', None)
+        app, url = yield from create_server(**server_params)
+        if client_params is None:
+            client_params = {}
+        client = Client(aiohttp.ClientSession(loop=loop, **client_params), url)
+        return app, client
+
+    yield maker
+    client.close()
+
+
+@pytest.mark.tryfirst
+def pytest_pycollect_makeitem(collector, name, obj):
+    if collector.funcnamefilter(name):
+        if not callable(obj):
+            return
+        item = pytest.Function(name, parent=collector)
+        if 'run_loop' in item.keywords:
+            return list(collector._genfunctions(name, obj))
+
+
+@pytest.mark.tryfirst
+def pytest_pyfunc_call(pyfuncitem):
+    """
+    Run asyncio marked test functions in an event loop instead of a normal
+    function call.
+    """
+    if 'run_loop' in pyfuncitem.keywords:
+        funcargs = pyfuncitem.funcargs
+        loop = funcargs['loop']
+        testargs = {arg: funcargs[arg]
+                    for arg in pyfuncitem._fixtureinfo.argnames}
+        loop.run_until_complete(pyfuncitem.obj(**testargs))
+        return True
+
+
+def pytest_runtest_setup(item):
+    if 'run_loop' in item.keywords and 'loop' not in item.fixturenames:
+        # inject an event loop fixture for all async tests
+        item.fixturenames.append('loop')
diff --git a/tests/test_no_auth.py b/tests/test_no_auth.py
index ca3c04a..8c3d99e 100644
--- a/tests/test_no_auth.py
+++ b/tests/test_no_auth.py
@@ -1,82 +1,41 @@
 import asyncio
-import socket
-import unittest
+import pytest
 
-import aiohttp
 from aiohttp import web
-from aiohttp_security import (authorized_userid, permits)
+from aiohttp_security import authorized_userid, permits
 
 
-class TestNoAuth(unittest.TestCase):
-
-    def setUp(self):
-        self.loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(None)
-        self.client = aiohttp.ClientSession(loop=self.loop)
-
-    def tearDown(self):
-        self.client.close()
-        self.loop.run_until_complete(self.handler.finish_connections())
-        self.srv.close()
-        self.loop.run_until_complete(self.srv.wait_closed())
-        self.loop.close()
-
-    def find_unused_port(self):
-        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        s.bind(('127.0.0.1', 0))
-        port = s.getsockname()[1]
-        s.close()
-        return port
+@pytest.mark.run_loop
+def test_authorized_userid(create_app_and_client):
 
     @asyncio.coroutine
-    def create_server(self):
-        app = web.Application(loop=self.loop)
+    def check(request):
+        userid = yield from authorized_userid(request)
+        assert userid is None
+        return web.Response()
 
-        port = self.find_unused_port()
-        self.handler = app.make_handler(
-            debug=False, keep_alive_on=False)
-        srv = yield from self.loop.create_server(
-            self.handler, '127.0.0.1', port)
-        url = "http://127.0.0.1:{}/".format(port)
-        self.srv = srv
-        return app, srv, url
+    app, client = yield from create_app_and_client()
+    app.router.add_route('GET', '/', check)
+    resp = yield from client.get('/')
+    assert 200 == resp.status
+    yield from resp.release()
 
-    def test_authorized_userid(self):
 
-        @asyncio.coroutine
-        def check(request):
-            userid = yield from authorized_userid(request)
-            self.assertIsNone(userid)
-            return web.Response()
+@pytest.mark.run_loop
+def test_permits(create_app_and_client):
 
-        @asyncio.coroutine
-        def go():
-            app, srv, url = yield from self.create_server()
-            app.router.add_route('GET', '/', check)
-            resp = yield from self.client.get(url)
-            self.assertEqual(200, resp.status)
-            yield from resp.release()
+    @asyncio.coroutine
+    def check(request):
+        ret = yield from permits(request, 'read')
+        assert ret
+        ret = yield from permits(request, 'write')
+        assert ret
+        ret = yield from permits(request, 'unknown')
+        assert ret
+        return web.Response()
 
-        self.loop.run_until_complete(go())
-
-    def test_permits(self):
-
-        @asyncio.coroutine
-        def check(request):
-            ret = yield from permits(request, 'read')
-            self.assertTrue(ret)
-            ret = yield from permits(request, 'write')
-            self.assertTrue(ret)
-            ret = yield from permits(request, 'unknown')
-            self.assertTrue(ret)
-            return web.Response()
-
-        @asyncio.coroutine
-        def go():
-            app, srv, url = yield from self.create_server()
-            app.router.add_route('GET', '/', check)
-            resp = yield from self.client.get(url)
-            self.assertEqual(200, resp.status)
-            yield from resp.release()
-
-        self.loop.run_until_complete(go())
+    app, client = yield from create_app_and_client()
+    app.router.add_route('GET', '/', check)
+    resp = yield from client.get('/')
+    assert 200 == resp.status
+    yield from resp.release()