From 19d7ee7b06e2c924ae6abfecb7ca72883fb3b2ee Mon Sep 17 00:00:00 2001 From: Alex Kuzmenko Date: Mon, 6 Feb 2017 01:11:32 +0200 Subject: [PATCH] Add `login_required` decorator --- aiohttp_security/__init__.py | 11 +++++++++-- aiohttp_security/api.py | 25 ++++++++++++++++++++++++- demo/handlers.py | 28 +++++++++------------------- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/aiohttp_security/__init__.py b/aiohttp_security/__init__.py index 3636192..5a6446b 100644 --- a/aiohttp_security/__init__.py +++ b/aiohttp_security/__init__.py @@ -1,5 +1,12 @@ from .abc import AbstractIdentityPolicy, AbstractAuthorizationPolicy -from .api import remember, forget, setup, authorized_userid, permits +from .api import ( + authorized_userid, + forget, + login_required, + permits, + remember, + setup, +) from .cookies_identity import CookiesIdentityPolicy from .session_identity import SessionIdentityPolicy @@ -10,4 +17,4 @@ __version__ = '0.1.1' __all__ = ('AbstractIdentityPolicy', 'AbstractAuthorizationPolicy', 'CookiesIdentityPolicy', 'SessionIdentityPolicy', 'remember', 'forget', 'authorized_userid', - 'permits', 'setup') + 'permits', 'setup', 'login_required') diff --git a/aiohttp_security/api.py b/aiohttp_security/api.py index 831f8de..950cfe1 100644 --- a/aiohttp_security/api.py +++ b/aiohttp_security/api.py @@ -1,3 +1,5 @@ +from functools import wraps + import asyncio from aiohttp import web from aiohttp_security.abc import (AbstractIdentityPolicy, @@ -5,6 +7,7 @@ from aiohttp_security.abc import (AbstractIdentityPolicy, IDENTITY_KEY = 'aiohttp_security_identity_policy' AUTZ_KEY = 'aiohttp_security_autz_policy' +AUTZ_REDIRECT_URL = 'aiohttp_security_autz_redirect_url' @asyncio.coroutine @@ -74,9 +77,29 @@ def permits(request, permission, context=None): return access -def setup(app, identity_policy, autz_policy): +def setup(app, identity_policy, autz_policy, redirect_url=None): assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy app[IDENTITY_KEY] = identity_policy app[AUTZ_KEY] = autz_policy + + if redirect_url: + app[AUTZ_REDIRECT_URL] = redirect_url + + +def login_required(permission): + def wrapper(handler): + @asyncio.coroutine + @wraps(handler) + def wrapped(request): + has_perm = yield from permits(request, permission) + if not has_perm: + redirect_url = request.app.get(AUTZ_REDIRECT_URL) + if redirect_url: + return web.HTTPFound(redirect_url) + raise web.HTTPForbidden() + response = yield from handler(request) + return response + return wrapped + return wrapper diff --git a/demo/handlers.py b/demo/handlers.py index 3de2aed..8c54379 100644 --- a/demo/handlers.py +++ b/demo/handlers.py @@ -1,27 +1,17 @@ import asyncio -import functools from aiohttp import web -from aiohttp_security import remember, forget, authorized_userid, permits +from aiohttp_security import ( + authorized_userid, + forget, + login_required, + remember, +) from .db_auth import check_credentials -def require(permission): - def wrapper(f): - @asyncio.coroutine - @functools.wraps(f) - def wrapped(self, request): - has_perm = yield from permits(request, permission) - if not has_perm: - message = 'User has no permission {}'.format(permission) - raise web.HTTPForbidden(body=message.encode()) - return (yield from f(self, request)) - return wrapped - return wrapper - - class Web(object): index_template = """ @@ -65,21 +55,21 @@ class Web(object): return web.HTTPUnauthorized( body=b'Invalid username/password combination') - @require('public') + @login_required('public') @asyncio.coroutine def logout(self, request): response = web.Response(body=b'You have been logged out') yield from forget(request, response) return response - @require('public') + @login_required('public') @asyncio.coroutine def internal_page(self, request): response = web.Response( body=b'This page is visible for all registered users') return response - @require('protected') + @login_required('protected') @asyncio.coroutine def protected_page(self, request): response = web.Response(body=b'You are on protected page')