From 19d7ee7b06e2c924ae6abfecb7ca72883fb3b2ee Mon Sep 17 00:00:00 2001
From: Alex Kuzmenko <alx.kuzm@gmail.com>
Date: Mon, 6 Feb 2017 01:11:32 +0200
Subject: [PATCH] Add `login_required` decorator

---
 aiohttp_security/__init__.py | 11 +++++++++--
 aiohttp_security/api.py      | 25 ++++++++++++++++++++++++-
 demo/handlers.py             | 28 +++++++++-------------------
 3 files changed, 42 insertions(+), 22 deletions(-)

diff --git a/aiohttp_security/__init__.py b/aiohttp_security/__init__.py
index 3636192..5a6446b 100644
--- a/aiohttp_security/__init__.py
+++ b/aiohttp_security/__init__.py
@@ -1,5 +1,12 @@
 from .abc import AbstractIdentityPolicy, AbstractAuthorizationPolicy
-from .api import remember, forget, setup, authorized_userid, permits
+from .api import (
+    authorized_userid,
+    forget,
+    login_required,
+    permits,
+    remember,
+    setup,
+)
 from .cookies_identity import CookiesIdentityPolicy
 from .session_identity import SessionIdentityPolicy
 
@@ -10,4 +17,4 @@ __version__ = '0.1.1'
 __all__ = ('AbstractIdentityPolicy', 'AbstractAuthorizationPolicy',
            'CookiesIdentityPolicy', 'SessionIdentityPolicy',
            'remember', 'forget', 'authorized_userid',
-           'permits', 'setup')
+           'permits', 'setup', 'login_required')
diff --git a/aiohttp_security/api.py b/aiohttp_security/api.py
index 831f8de..950cfe1 100644
--- a/aiohttp_security/api.py
+++ b/aiohttp_security/api.py
@@ -1,3 +1,5 @@
+from functools import wraps
+
 import asyncio
 from aiohttp import web
 from aiohttp_security.abc import (AbstractIdentityPolicy,
@@ -5,6 +7,7 @@ from aiohttp_security.abc import (AbstractIdentityPolicy,
 
 IDENTITY_KEY = 'aiohttp_security_identity_policy'
 AUTZ_KEY = 'aiohttp_security_autz_policy'
+AUTZ_REDIRECT_URL = 'aiohttp_security_autz_redirect_url'
 
 
 @asyncio.coroutine
@@ -74,9 +77,29 @@ def permits(request, permission, context=None):
     return access
 
 
-def setup(app, identity_policy, autz_policy):
+def setup(app, identity_policy, autz_policy, redirect_url=None):
     assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy
     assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy
 
     app[IDENTITY_KEY] = identity_policy
     app[AUTZ_KEY] = autz_policy
+
+    if redirect_url:
+        app[AUTZ_REDIRECT_URL] = redirect_url
+
+
+def login_required(permission):
+    def wrapper(handler):
+        @asyncio.coroutine
+        @wraps(handler)
+        def wrapped(request):
+            has_perm = yield from permits(request, permission)
+            if not has_perm:
+                redirect_url = request.app.get(AUTZ_REDIRECT_URL)
+                if redirect_url:
+                    return web.HTTPFound(redirect_url)
+                raise web.HTTPForbidden()
+            response = yield from handler(request)
+            return response
+        return wrapped
+    return wrapper
diff --git a/demo/handlers.py b/demo/handlers.py
index 3de2aed..8c54379 100644
--- a/demo/handlers.py
+++ b/demo/handlers.py
@@ -1,27 +1,17 @@
 import asyncio
-import functools
 
 from aiohttp import web
 
-from aiohttp_security import remember, forget, authorized_userid, permits
+from aiohttp_security import (
+    authorized_userid,
+    forget,
+    login_required,
+    remember,
+)
 
 from .db_auth import check_credentials
 
 
-def require(permission):
-    def wrapper(f):
-        @asyncio.coroutine
-        @functools.wraps(f)
-        def wrapped(self, request):
-            has_perm = yield from permits(request, permission)
-            if not has_perm:
-                message = 'User has no permission {}'.format(permission)
-                raise web.HTTPForbidden(body=message.encode())
-            return (yield from f(self, request))
-        return wrapped
-    return wrapper
-
-
 class Web(object):
     index_template = """
 <!doctype html>
@@ -65,21 +55,21 @@ class Web(object):
         return web.HTTPUnauthorized(
             body=b'Invalid username/password combination')
 
-    @require('public')
+    @login_required('public')
     @asyncio.coroutine
     def logout(self, request):
         response = web.Response(body=b'You have been logged out')
         yield from forget(request, response)
         return response
 
-    @require('public')
+    @login_required('public')
     @asyncio.coroutine
     def internal_page(self, request):
         response = web.Response(
             body=b'This page is visible for all registered users')
         return response
 
-    @require('protected')
+    @login_required('protected')
     @asyncio.coroutine
     def protected_page(self, request):
         response = web.Response(body=b'You are on protected page')