Add `login_required` decorator

This commit is contained in:
Alex Kuzmenko 2017-02-06 01:11:32 +02:00
parent 6e4355ce3c
commit 19d7ee7b06
3 changed files with 42 additions and 22 deletions

View File

@ -1,5 +1,12 @@
from .abc import AbstractIdentityPolicy, AbstractAuthorizationPolicy from .abc import AbstractIdentityPolicy, AbstractAuthorizationPolicy
from .api import remember, forget, setup, authorized_userid, permits from .api import (
authorized_userid,
forget,
login_required,
permits,
remember,
setup,
)
from .cookies_identity import CookiesIdentityPolicy from .cookies_identity import CookiesIdentityPolicy
from .session_identity import SessionIdentityPolicy from .session_identity import SessionIdentityPolicy
@ -10,4 +17,4 @@ __version__ = '0.1.1'
__all__ = ('AbstractIdentityPolicy', 'AbstractAuthorizationPolicy', __all__ = ('AbstractIdentityPolicy', 'AbstractAuthorizationPolicy',
'CookiesIdentityPolicy', 'SessionIdentityPolicy', 'CookiesIdentityPolicy', 'SessionIdentityPolicy',
'remember', 'forget', 'authorized_userid', 'remember', 'forget', 'authorized_userid',
'permits', 'setup') 'permits', 'setup', 'login_required')

View File

@ -1,3 +1,5 @@
from functools import wraps
import asyncio import asyncio
from aiohttp import web from aiohttp import web
from aiohttp_security.abc import (AbstractIdentityPolicy, from aiohttp_security.abc import (AbstractIdentityPolicy,
@ -5,6 +7,7 @@ from aiohttp_security.abc import (AbstractIdentityPolicy,
IDENTITY_KEY = 'aiohttp_security_identity_policy' IDENTITY_KEY = 'aiohttp_security_identity_policy'
AUTZ_KEY = 'aiohttp_security_autz_policy' AUTZ_KEY = 'aiohttp_security_autz_policy'
AUTZ_REDIRECT_URL = 'aiohttp_security_autz_redirect_url'
@asyncio.coroutine @asyncio.coroutine
@ -74,9 +77,29 @@ def permits(request, permission, context=None):
return access return access
def setup(app, identity_policy, autz_policy): def setup(app, identity_policy, autz_policy, redirect_url=None):
assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy
assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy
app[IDENTITY_KEY] = identity_policy app[IDENTITY_KEY] = identity_policy
app[AUTZ_KEY] = autz_policy app[AUTZ_KEY] = autz_policy
if redirect_url:
app[AUTZ_REDIRECT_URL] = redirect_url
def login_required(permission):
def wrapper(handler):
@asyncio.coroutine
@wraps(handler)
def wrapped(request):
has_perm = yield from permits(request, permission)
if not has_perm:
redirect_url = request.app.get(AUTZ_REDIRECT_URL)
if redirect_url:
return web.HTTPFound(redirect_url)
raise web.HTTPForbidden()
response = yield from handler(request)
return response
return wrapped
return wrapper

View File

@ -1,27 +1,17 @@
import asyncio import asyncio
import functools
from aiohttp import web from aiohttp import web
from aiohttp_security import remember, forget, authorized_userid, permits from aiohttp_security import (
authorized_userid,
forget,
login_required,
remember,
)
from .db_auth import check_credentials from .db_auth import check_credentials
def require(permission):
def wrapper(f):
@asyncio.coroutine
@functools.wraps(f)
def wrapped(self, request):
has_perm = yield from permits(request, permission)
if not has_perm:
message = 'User has no permission {}'.format(permission)
raise web.HTTPForbidden(body=message.encode())
return (yield from f(self, request))
return wrapped
return wrapper
class Web(object): class Web(object):
index_template = """ index_template = """
<!doctype html> <!doctype html>
@ -65,21 +55,21 @@ class Web(object):
return web.HTTPUnauthorized( return web.HTTPUnauthorized(
body=b'Invalid username/password combination') body=b'Invalid username/password combination')
@require('public') @login_required('public')
@asyncio.coroutine @asyncio.coroutine
def logout(self, request): def logout(self, request):
response = web.Response(body=b'You have been logged out') response = web.Response(body=b'You have been logged out')
yield from forget(request, response) yield from forget(request, response)
return response return response
@require('public') @login_required('public')
@asyncio.coroutine @asyncio.coroutine
def internal_page(self, request): def internal_page(self, request):
response = web.Response( response = web.Response(
body=b'This page is visible for all registered users') body=b'This page is visible for all registered users')
return response return response
@require('protected') @login_required('protected')
@asyncio.coroutine @asyncio.coroutine
def protected_page(self, request): def protected_page(self, request):
response = web.Response(body=b'You are on protected page') response = web.Response(body=b'You are on protected page')