From 29869c710fddb142d9eaa164b63cc14506b010e4 Mon Sep 17 00:00:00 2001
From: Andrew Svetlov <andrew.svetlov@gmail.com>
Date: Wed, 5 Aug 2015 13:32:49 +0300
Subject: [PATCH] Add test for remember/forbid when library was not setted up

---
 aiohttp_security/api.py   | 19 ++++++++-
 tests/test_no_identity.py | 82 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 99 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_no_identity.py

diff --git a/aiohttp_security/api.py b/aiohttp_security/api.py
index 58497a3..effe783 100644
--- a/aiohttp_security/api.py
+++ b/aiohttp_security/api.py
@@ -1,4 +1,5 @@
 import asyncio
+from aiohttp import web
 from aiohttp_security.abc import (AbstractIdentityPolicy,
                                   AbstractAuthorizationPolicy)
 
@@ -8,13 +9,27 @@ AUTZ_KEY = 'aiohttp_security_autz_policy'
 
 @asyncio.coroutine
 def remember(request, response, identity, **kwargs):
-    identity_policy = request.app[IDENTITY_KEY]
+    identity_policy = request.app.get(IDENTITY_KEY)
+    if identity_policy is None:
+        text = ("Security subsystem is not initialized, "
+                "call aiohttp_security.setup(...) first")
+        # in order to see meaningful exception message both: on console
+        # output and rendered page we add same message to *reason* and
+        # *text* arguments.
+        raise web.HTTPInternalServerError(reason=text, text=text)
     yield from identity_policy.remember(request, response, identity, **kwargs)
 
 
 @asyncio.coroutine
 def forget(request, response):
-    identity_policy = request.app[IDENTITY_KEY]
+    identity_policy = request.app.get(IDENTITY_KEY)
+    if identity_policy is None:
+        text = ("Security subsystem is not initialized, "
+                "call aiohttp_security.setup(...) first")
+        # in order to see meaningful exception message both: on console
+        # output and rendered page we add same message to *reason* and
+        # *text* arguments.
+        raise web.HTTPInternalServerError(reason=text, text=text)
     yield from identity_policy.forget(request, response)
 
 
diff --git a/tests/test_no_identity.py b/tests/test_no_identity.py
new file mode 100644
index 0000000..e45fd12
--- /dev/null
+++ b/tests/test_no_identity.py
@@ -0,0 +1,82 @@
+import asyncio
+import socket
+import unittest
+
+import aiohttp
+from aiohttp import web
+from aiohttp_security import remember, forget
+
+
+class TestNoAuth(unittest.TestCase):
+
+    def setUp(self):
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(None)
+        self.client = aiohttp.ClientSession(loop=self.loop)
+
+    def tearDown(self):
+        self.client.close()
+        self.loop.run_until_complete(self.handler.finish_connections())
+        self.srv.close()
+        self.loop.run_until_complete(self.srv.wait_closed())
+        self.loop.close()
+
+    def find_unused_port(self):
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.bind(('127.0.0.1', 0))
+        port = s.getsockname()[1]
+        s.close()
+        return port
+
+    @asyncio.coroutine
+    def create_server(self):
+        app = web.Application(loop=self.loop)
+
+        port = self.find_unused_port()
+        self.handler = app.make_handler(
+            debug=False, keep_alive_on=False)
+        srv = yield from self.loop.create_server(
+            self.handler, '127.0.0.1', port)
+        url = "http://127.0.0.1:{}/".format(port)
+        self.srv = srv
+        return app, srv, url
+
+    def test_remember(self):
+
+        @asyncio.coroutine
+        def do_remember(request):
+            response = web.Response()
+            yield from remember(request, response, 'Andrew')
+
+        @asyncio.coroutine
+        def go():
+            app, srv, url = yield from self.create_server()
+            app.router.add_route('POST', '/', do_remember)
+            resp = yield from self.client.post(url)
+            self.assertEqual(500, resp.status)
+            self.assertEqual(('Security subsystem is not initialized, '
+                              'call aiohttp_security.setup(...) first'),
+                             resp.reason)
+            yield from resp.release()
+
+        self.loop.run_until_complete(go())
+
+    def test_forget(self):
+
+        @asyncio.coroutine
+        def do_forget(request):
+            response = web.Response()
+            yield from forget(request, response)
+
+        @asyncio.coroutine
+        def go():
+            app, srv, url = yield from self.create_server()
+            app.router.add_route('POST', '/', do_forget)
+            resp = yield from self.client.post(url)
+            self.assertEqual(500, resp.status)
+            self.assertEqual(('Security subsystem is not initialized, '
+                              'call aiohttp_security.setup(...) first'),
+                             resp.reason)
+            yield from resp.release()
+
+        self.loop.run_until_complete(go())