Add type annotations.

This commit is contained in:
Sam Bull 2020-12-18 17:58:38 +00:00
parent 2247eb72f9
commit b3b5367460
20 changed files with 200 additions and 112 deletions

37
.mypy.ini Normal file
View File

@ -0,0 +1,37 @@
[mypy]
files = aiohttp_security, demo, tests
check_untyped_defs = True
follow_imports_for_stubs = True
disallow_any_decorated = True
disallow_any_generics = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
implicit_reexport = False
no_implicit_optional = True
show_error_codes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
warn_unreachable = True
warn_unused_ignores = True
disallow_any_unimported = True
warn_return_any = True
[mypy-aiohttp_security.abc.*]
disallow_any_decorated = False
[mypy-tests.*]
disallow_any_decorated = False
disallow_untyped_defs = False
[mypy-aiopg.*]
ignore_missing_imports = True
[mypy-aioredis.*]
ignore_missing_imports = True
[mypy-passlib.*]
ignore_missing_imports = True

View File

@ -1,4 +1,8 @@
import abc import abc
from enum import Enum
from typing import Any, Optional, Union
from aiohttp import web
# see http://plope.com/pyramid_auth_design_api_postmortem # see http://plope.com/pyramid_auth_design_api_postmortem
@ -6,13 +10,14 @@ import abc
class AbstractIdentityPolicy(metaclass=abc.ABCMeta): class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def identify(self, request): async def identify(self, request: web.Request) -> Optional[str]:
"""Return the claimed identity of the user associated request or """Return the claimed identity of the user associated request or
``None`` if no identity can be found associated with the request.""" ``None`` if no identity can be found associated with the request."""
pass pass
@abc.abstractmethod @abc.abstractmethod
async def remember(self, request, response, identity, **kwargs): async def remember(self, request: web.Request, response: web.StreamResponse,
identity: str, **kwargs: Any) -> None:
"""Remember identity. """Remember identity.
Modify response object by filling it's headers with remembered user. Modify response object by filling it's headers with remembered user.
@ -23,7 +28,7 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def forget(self, request, response): async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
""" Modify response which can be used to 'forget' the """ Modify response which can be used to 'forget' the
current identity on subsequent requests.""" current identity on subsequent requests."""
pass pass
@ -32,7 +37,8 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta): class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def permits(self, identity, permission, context=None): async def permits(self, identity: str, permission: Union[str, Enum],
context: Any = None) -> bool:
"""Check user permissions. """Check user permissions.
Return True if the identity is allowed the permission in the Return True if the identity is allowed the permission in the
@ -41,7 +47,7 @@ class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def authorized_userid(self, identity): async def authorized_userid(self, identity: str) -> Optional[str]:
"""Retrieve authorized user id. """Retrieve authorized user id.
Return the user_id of the user identified by the identity Return the user_id of the user identified by the identity

View File

@ -1,15 +1,23 @@
import enum import enum
import warnings import warnings
from aiohttp import web
from aiohttp_security.abc import (AbstractIdentityPolicy,
AbstractAuthorizationPolicy)
from functools import wraps from functools import wraps
from typing import Any, Callable, Optional, TypeVar, Union
from aiohttp import web
from aiohttp_security.abc import AbstractAuthorizationPolicy, AbstractIdentityPolicy
IDENTITY_KEY = 'aiohttp_security_identity_policy' IDENTITY_KEY = 'aiohttp_security_identity_policy'
AUTZ_KEY = 'aiohttp_security_autz_policy' AUTZ_KEY = 'aiohttp_security_autz_policy'
# _AIP/_AAP are shorthand for Optional[policy] when we retrieve from request.
_AAP = Optional[AbstractAuthorizationPolicy]
_AIP = Optional[AbstractIdentityPolicy]
_Handler = TypeVar('_Handler', bound=Union[Callable[[web.Request], Any],
Callable[[object, web.Request], Any]])
async def remember(request, response, identity, **kwargs):
async def remember(request: web.Request, response: web.StreamResponse,
identity: str, **kwargs: Any) -> None:
"""Remember identity into response. """Remember identity into response.
The action is performed by identity_policy.remember() The action is performed by identity_policy.remember()
@ -30,7 +38,7 @@ async def remember(request, response, identity, **kwargs):
await identity_policy.remember(request, response, identity, **kwargs) await identity_policy.remember(request, response, identity, **kwargs)
async def forget(request, response): async def forget(request: web.Request, response: web.StreamResponse) -> None:
"""Forget previously remembered identity. """Forget previously remembered identity.
Usually it clears cookie or server-side storage to forget user Usually it clears cookie or server-side storage to forget user
@ -47,9 +55,9 @@ async def forget(request, response):
await identity_policy.forget(request, response) await identity_policy.forget(request, response)
async def authorized_userid(request): async def authorized_userid(request: web.Request) -> Optional[str]:
identity_policy = request.config_dict.get(IDENTITY_KEY) identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY)
autz_policy = request.config_dict.get(AUTZ_KEY) autz_policy: _AAP = request.config_dict.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None: if identity_policy is None or autz_policy is None:
return None return None
identity = await identity_policy.identify(request) identity = await identity_policy.identify(request)
@ -59,20 +67,21 @@ async def authorized_userid(request):
return user_id return user_id
async def permits(request, permission, context=None): async def permits(request: web.Request, permission: Union[str, enum.Enum],
context: Any = None) -> bool:
assert isinstance(permission, (str, enum.Enum)), permission assert isinstance(permission, (str, enum.Enum)), permission
assert permission assert permission
identity_policy = request.config_dict.get(IDENTITY_KEY) identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY)
autz_policy = request.config_dict.get(AUTZ_KEY) autz_policy: _AAP = request.config_dict.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None: if identity_policy is None or autz_policy is None:
return True return True
identity = await identity_policy.identify(request) identity = await identity_policy.identify(request)
# non-registered user still may has some permissions # non-registered user still may have some permissions
access = await autz_policy.permits(identity, permission, context) access = await autz_policy.permits(identity, permission, context)
return access return access
async def is_anonymous(request): async def is_anonymous(request: web.Request) -> bool:
"""Check if user is anonymous. """Check if user is anonymous.
User is considered anonymous if there is not identity User is considered anonymous if there is not identity
@ -87,7 +96,7 @@ async def is_anonymous(request):
return False return False
async def check_authorized(request): async def check_authorized(request: web.Request) -> str:
"""Checker that raises HTTPUnauthorized for anonymous users. """Checker that raises HTTPUnauthorized for anonymous users.
""" """
userid = await authorized_userid(request) userid = await authorized_userid(request)
@ -96,31 +105,32 @@ async def check_authorized(request):
return userid return userid
def login_required(fn): def login_required(fn: _Handler) -> _Handler:
"""Decorator that restrict access only for authorized users. """Decorator that restrict access only for authorized users.
User is considered authorized if authorized_userid User is considered authorized if authorized_userid
returns some value. returns some value.
""" """
@wraps(fn) @wraps(fn)
async def wrapped(*args, **kwargs): async def wrapped(*args: Union[object, web.Request]) -> Any:
request = args[-1] request = args[-1]
if not isinstance(request, web.BaseRequest): if not isinstance(request, web.Request):
msg = ("Incorrect decorator usage. " msg = ("Incorrect decorator usage. "
"Expecting `def handler(request)` " "Expecting `def handler(request)` "
"or `def handler(self, request)`.") "or `def handler(self, request)`.")
raise RuntimeError(msg) raise RuntimeError(msg)
await check_authorized(request) await check_authorized(request)
return await fn(*args, **kwargs) return await fn(*args) # type: ignore[arg-type]
warnings.warn("login_required decorator is deprecated, " warnings.warn("login_required decorator is deprecated, "
"use check_authorized instead", "use check_authorized instead",
DeprecationWarning) DeprecationWarning)
return wrapped return wrapped # type: ignore[return-value]
async def check_permission(request, permission, context=None): async def check_permission(request: web.Request, permission: Union[str, enum.Enum],
context: Any = None) -> None:
"""Checker that passes only to authoraised users with given permission. """Checker that passes only to authoraised users with given permission.
If user is not authorized - raises HTTPUnauthorized, If user is not authorized - raises HTTPUnauthorized,
@ -134,10 +144,7 @@ async def check_permission(request, permission, context=None):
raise web.HTTPForbidden() raise web.HTTPForbidden()
def has_permission( def has_permission(permission: Union[str, enum.Enum], context: Any = None): # type: ignore
permission,
context=None,
):
"""Decorator that restricts access only for authorized users """Decorator that restricts access only for authorized users
with correct permissions. with correct permissions.
@ -145,11 +152,11 @@ def has_permission(
if user is authorized and does not have permission - if user is authorized and does not have permission -
raises HTTPForbidden. raises HTTPForbidden.
""" """
def wrapper(fn): def wrapper(fn): # type: ignore
@wraps(fn) @wraps(fn)
async def wrapped(*args, **kwargs): async def wrapped(*args, **kwargs): # type: ignore
request = args[-1] request = args[-1]
if not isinstance(request, web.BaseRequest): if not isinstance(request, web.Request):
msg = ("Incorrect decorator usage. " msg = ("Incorrect decorator usage. "
"Expecting `def handler(request)` " "Expecting `def handler(request)` "
"or `def handler(self, request)`.") "or `def handler(self, request)`.")
@ -166,7 +173,8 @@ def has_permission(
return wrapper return wrapper
def setup(app, identity_policy, autz_policy): def setup(app: web.Application, identity_policy: AbstractIdentityPolicy,
autz_policy: AbstractAuthorizationPolicy) -> None:
assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy
assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy

View File

@ -5,28 +5,32 @@ more handy.
""" """
from aiohttp import web
from typing import Any, NewType, Optional, Union, cast
from .abc import AbstractIdentityPolicy from .abc import AbstractIdentityPolicy
_Sentinel = NewType('_Sentinel', object)
sentinel = object() sentinel = _Sentinel(object())
class CookiesIdentityPolicy(AbstractIdentityPolicy): class CookiesIdentityPolicy(AbstractIdentityPolicy):
def __init__(self): def __init__(self) -> None:
self._cookie_name = 'AIOHTTP_SECURITY' self._cookie_name = 'AIOHTTP_SECURITY'
self._max_age = 30 * 24 * 3600 self._max_age = 30 * 24 * 3600
async def identify(self, request): async def identify(self, request: web.Request) -> Optional[str]:
identity = request.cookies.get(self._cookie_name) return request.cookies.get(self._cookie_name)
return identity
async def remember(self, request, response, identity, max_age=sentinel, async def remember(self, request: web.Request, response: web.StreamResponse,
**kwargs): identity: str, max_age: Union[_Sentinel, Optional[int]] = sentinel,
**kwargs: Any) -> None:
if max_age is sentinel: if max_age is sentinel:
max_age = self._max_age max_age = self._max_age
max_age = cast(Optional[int], max_age)
response.set_cookie(self._cookie_name, identity, response.set_cookie(self._cookie_name, identity,
max_age=max_age, **kwargs) max_age=max_age, **kwargs)
async def forget(self, request, response): async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
response.del_cookie(self._cookie_name) response.del_cookie(self._cookie_name)

View File

@ -2,12 +2,17 @@
""" """
from typing import Optional
from aiohttp import web
from .abc import AbstractIdentityPolicy from .abc import AbstractIdentityPolicy
try: try:
import jwt import jwt
HAS_JWT = True
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
jwt = None HAS_JWT = False
AUTH_HEADER_NAME = 'Authorization' AUTH_HEADER_NAME = 'Authorization'
@ -15,21 +20,21 @@ AUTH_SCHEME = 'Bearer '
class JWTIdentityPolicy(AbstractIdentityPolicy): class JWTIdentityPolicy(AbstractIdentityPolicy):
def __init__(self, secret, algorithm='HS256'): def __init__(self, secret: str, algorithm: str = 'HS256'):
if jwt is None: if not HAS_JWT:
raise RuntimeError('Please install `PyJWT`') raise RuntimeError('Please install `PyJWT`')
self.secret = secret self.secret = secret
self.algorithm = algorithm self.algorithm = algorithm
async def identify(self, request): async def identify(self, request: web.Request) -> Optional[str]:
header_identity = request.headers.get(AUTH_HEADER_NAME) header_identity = request.headers.get(AUTH_HEADER_NAME)
if header_identity is None: if header_identity is None:
return return None
if not header_identity.startswith(AUTH_SCHEME): if not header_identity.startswith(AUTH_SCHEME):
raise ValueError('Invalid authorization scheme. ' + raise ValueError('Invalid authorization scheme. ' +
'Should be `Bearer <token>`') 'Should be `{}<token>`'.format(AUTH_SCHEME))
token = header_identity.split(' ')[1].strip() token = header_identity.split(' ')[1].strip()
@ -38,8 +43,9 @@ class JWTIdentityPolicy(AbstractIdentityPolicy):
algorithms=[self.algorithm]) algorithms=[self.algorithm])
return identity return identity
async def remember(self, *args, **kwargs): # pragma: no cover async def remember(self, request: web.Request, response: web.StreamResponse,
identity: str, **kwargs: None) -> None:
pass pass
async def forget(self, request, response): # pragma: no cover async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
pass pass

View File

View File

@ -4,6 +4,7 @@ aiohttp_session.setup() should be called on application initialization
to configure aiohttp_session properly. to configure aiohttp_session properly.
""" """
from aiohttp import web
try: try:
from aiohttp_session import get_session from aiohttp_session import get_session
HAS_AIOHTTP_SESSION = True HAS_AIOHTTP_SESSION = True
@ -15,21 +16,22 @@ from .abc import AbstractIdentityPolicy
class SessionIdentityPolicy(AbstractIdentityPolicy): class SessionIdentityPolicy(AbstractIdentityPolicy):
def __init__(self, session_key='AIOHTTP_SECURITY'): def __init__(self, session_key: str = 'AIOHTTP_SECURITY'):
self._session_key = session_key self._session_key = session_key
if not HAS_AIOHTTP_SESSION: # pragma: no cover if not HAS_AIOHTTP_SESSION: # pragma: no cover
raise ImportError( raise ImportError(
'SessionIdentityPolicy requires `aiohttp_session`') 'SessionIdentityPolicy requires `aiohttp_session`')
async def identify(self, request): async def identify(self, request: web.Request) -> str:
session = await get_session(request) session = await get_session(request)
return session.get(self._session_key) return session.get(self._session_key)
async def remember(self, request, response, identity, **kwargs): async def remember(self, request: web.Request, response: web.StreamResponse,
identity: str, **kwargs: None) -> None:
session = await get_session(request) session = await get_session(request)
session[self._session_key] = identity session[self._session_key] = identity
async def forget(self, request, response): async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
session = await get_session(request) session = await get_session(request)
session.pop(self._session_key, None) session.pop(self._session_key, None)

0
demo/__init__.py Normal file
View File

View File

View File

@ -1,5 +1,7 @@
import sqlalchemy as sa from enum import Enum
from typing import Any, Optional, Union
import sqlalchemy as sa
from aiohttp_security.abc import AbstractAuthorizationPolicy from aiohttp_security.abc import AbstractAuthorizationPolicy
from passlib.hash import sha256_crypt from passlib.hash import sha256_crypt
@ -7,13 +9,13 @@ from . import db
class DBAuthorizationPolicy(AbstractAuthorizationPolicy): class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
def __init__(self, dbengine): def __init__(self, dbengine: Any):
self.dbengine = dbengine self.dbengine = dbengine
async def authorized_userid(self, identity): async def authorized_userid(self, identity: str) -> Optional[str]:
async with self.dbengine.acquire() as conn: async with self.dbengine.acquire() as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
query = db.users.count().where(where) query = db.users.count().where(where)
ret = await conn.scalar(query) ret = await conn.scalar(query)
if ret: if ret:
@ -21,13 +23,11 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
else: else:
return None return None
async def permits(self, identity, permission, context=None): async def permits(self, identity: str, permission: Union[str, Enum],
if identity is None: context: None = None) -> bool:
return False
async with self.dbengine.acquire() as conn: async with self.dbengine.acquire() as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
query = db.users.select().where(where) query = db.users.select().where(where)
ret = await conn.execute(query) ret = await conn.execute(query)
user = await ret.fetchone() user = await ret.fetchone()
@ -49,14 +49,14 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
return False return False
async def check_credentials(db_engine, username, password): async def check_credentials(db_engine: Any, username: str, password: str) -> bool:
async with db_engine.acquire() as conn: async with db_engine.acquire() as conn:
where = sa.and_(db.users.c.login == username, where = sa.and_(db.users.c.login == username,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
query = db.users.select().where(where) query = db.users.select().where(where)
ret = await conn.execute(query) ret = await conn.execute(query)
user = await ret.fetchone() user = await ret.fetchone()
if user is not None: if user is not None:
hashed = user[2] hashed = user[2]
return sha256_crypt.verify(password, hashed) return sha256_crypt.verify(password, hashed) # type: ignore[no-any-return]
return False return False

View File

@ -1,4 +1,5 @@
from textwrap import dedent from textwrap import dedent
from typing import NoReturn
from aiohttp import web from aiohttp import web
@ -27,7 +28,7 @@ class Web(object):
</body> </body>
""") """)
async def index(self, request): async def index(self, request: web.Request) -> web.Response:
username = await authorized_userid(request) username = await authorized_userid(request)
if username: if username:
template = self.index_template.format( template = self.index_template.format(
@ -37,37 +38,41 @@ class Web(object):
response = web.Response(body=template.encode()) response = web.Response(body=template.encode())
return response return response
async def login(self, request): async def login(self, request: web.Request) -> NoReturn:
response = web.HTTPFound('/') invalid_resp = web.HTTPUnauthorized(body=b'Invalid username/password combination')
form = await request.post() form = await request.post()
login = form.get('login') login = form.get('login')
password = form.get('password') password = form.get('password')
db_engine = request.app.db_engine db_engine = request.app['db_engine']
if not (isinstance(login, str) and isinstance(password, str)):
raise invalid_resp
if await check_credentials(db_engine, login, password): if await check_credentials(db_engine, login, password):
response = web.HTTPFound('/')
await remember(request, response, login) await remember(request, response, login)
raise response raise response
raise web.HTTPUnauthorized( raise invalid_resp
body=b'Invalid username/password combination')
async def logout(self, request): async def logout(self, request: web.Request) -> web.Response:
await check_authorized(request) await check_authorized(request)
response = web.Response(body=b'You have been logged out') response = web.Response(body=b'You have been logged out')
await forget(request, response) await forget(request, response)
return response return response
async def internal_page(self, request): async def internal_page(self, request: web.Request) -> web.Response:
await check_permission(request, 'public') await check_permission(request, 'public')
response = web.Response( response = web.Response(
body=b'This page is visible for all registered users') body=b'This page is visible for all registered users')
return response return response
async def protected_page(self, request): async def protected_page(self, request: web.Request) -> web.Response:
await check_permission(request, 'protected') await check_permission(request, 'protected')
response = web.Response(body=b'You are on protected page') response = web.Response(body=b'You are on protected page')
return response return response
def configure(self, app): def configure(self, app: web.Application) -> None:
router = app.router router = app.router
router.add_route('GET', '/', self.index, name='index') router.add_route('GET', '/', self.index, name='index')
router.add_route('POST', '/login', self.login, name='login') router.add_route('POST', '/login', self.login, name='login')

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
from typing import Any, Tuple
from aiohttp import web from aiohttp import web
from aiohttp_session import setup as setup_session from aiohttp_session import setup as setup_session
@ -13,14 +14,14 @@ from demo.database_auth.db_auth import DBAuthorizationPolicy
from demo.database_auth.handlers import Web from demo.database_auth.handlers import Web
async def init(loop): async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]:
redis_pool = await create_pool(('localhost', 6379)) redis_pool = await create_pool(('localhost', 6379))
db_engine = await create_engine(user='aiohttp_security', db_engine = await create_engine(user='aiohttp_security',
password='aiohttp_security', password='aiohttp_security',
database='aiohttp_security', database='aiohttp_security',
host='127.0.0.1') host='127.0.0.1')
app = web.Application() app = web.Application()
app.db_engine = db_engine app['db_engine'] = db_engine
setup_session(app, RedisStorage(redis_pool)) setup_session(app, RedisStorage(redis_pool))
setup_security(app, setup_security(app,
SessionIdentityPolicy(), SessionIdentityPolicy(),
@ -35,7 +36,7 @@ async def init(loop):
return srv, app, handler return srv, app, handler
async def finalize(srv, app, handler): async def finalize(srv: Any, app: Any, handler: Any) -> None:
sock = srv.sockets[0] sock = srv.sockets[0]
app.loop.remove_reader(sock.fileno()) app.loop.remove_reader(sock.fileno())
sock.close() sock.close()
@ -46,7 +47,7 @@ async def finalize(srv, app, handler):
await app.finish() await app.finish()
def main(): def main() -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
srv, app, handler = loop.run_until_complete(init(loop)) srv, app, handler = loop.run_until_complete(init(loop))
try: try:

View File

View File

@ -1,20 +1,25 @@
from enum import Enum
from typing import Dict, Optional, Union
from aiohttp_security.abc import AbstractAuthorizationPolicy from aiohttp_security.abc import AbstractAuthorizationPolicy
from .users import User
class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy): class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
def __init__(self, user_map): def __init__(self, user_map: Dict[str, User]):
super().__init__() super().__init__()
self.user_map = user_map self.user_map = user_map
async def authorized_userid(self, identity): async def authorized_userid(self, identity: str) -> Optional[str]:
"""Retrieve authorized user id. """Retrieve authorized user id.
Return the user_id of the user identified by the identity Return the user_id of the user identified by the identity
or 'None' if no user exists related to the identity. or 'None' if no user exists related to the identity.
""" """
if identity in self.user_map: return identity if identity in self.user_map else None
return identity
async def permits(self, identity, permission, context=None): async def permits(self, identity: str, permission: Union[str, Enum],
context: None = None) -> bool:
"""Check user permissions. """Check user permissions.
Return True if the identity is allowed the permission in the Return True if the identity is allowed the permission in the
current context, else return False. current context, else return False.
@ -26,7 +31,7 @@ class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
return permission in user.permissions return permission in user.permissions
async def check_credentials(user_map, username, password): async def check_credentials(user_map: Dict[str, User], username: str, password: str) -> bool:
user = user_map.get(username) user = user_map.get(username)
if not user: if not user:
return False return False

View File

@ -1,4 +1,5 @@
from textwrap import dedent from textwrap import dedent
from typing import Dict, NoReturn
from aiohttp import web from aiohttp import web
@ -8,6 +9,7 @@ from aiohttp_security import (
) )
from .authz import check_credentials from .authz import check_credentials
from .users import User
index_template = dedent(""" index_template = dedent("""
@ -27,7 +29,7 @@ index_template = dedent("""
""") """)
async def index(request): async def index(request: web.Request) -> web.Response:
username = await authorized_userid(request) username = await authorized_userid(request)
if username: if username:
template = index_template.format( template = index_template.format(
@ -40,22 +42,26 @@ async def index(request):
) )
async def login(request): async def login(request: web.Request) -> NoReturn:
response = web.HTTPFound('/') user_map: Dict[str, User] = request.app['user_map']
invalid_response = web.HTTPUnauthorized(body='Invalid username / password combination')
form = await request.post() form = await request.post()
username = form.get('username') username = form.get('username')
password = form.get('password') password = form.get('password')
verified = await check_credentials( if not (isinstance(username, str) and isinstance(password, str)):
request.app.user_map, username, password) raise invalid_response
verified = await check_credentials(user_map, username, password)
if verified: if verified:
response = web.HTTPFound('/')
await remember(request, response, username) await remember(request, response, username)
return response raise response
return web.HTTPUnauthorized(body='Invalid username / password combination') raise invalid_response
async def logout(request): async def logout(request: web.Request) -> web.Response:
await check_authorized(request) await check_authorized(request)
response = web.Response( response = web.Response(
text='You have been logged out', text='You have been logged out',
@ -65,7 +71,7 @@ async def logout(request):
return response return response
async def internal_page(request): async def internal_page(request: web.Request) -> web.Response:
await check_permission(request, 'public') await check_permission(request, 'public')
response = web.Response( response = web.Response(
text='This page is visible for all registered users', text='This page is visible for all registered users',
@ -74,7 +80,7 @@ async def internal_page(request):
return response return response
async def protected_page(request): async def protected_page(request: web.Request) -> web.Response:
await check_permission(request, 'protected') await check_permission(request, 'protected')
response = web.Response( response = web.Response(
text='You are on protected page', text='You are on protected page',
@ -83,7 +89,7 @@ async def protected_page(request):
return response return response
def configure_handlers(app): def configure_handlers(app: web.Application) -> None:
router = app.router router = app.router
router.add_get('/', index, name='index') router.add_get('/', index, name='index')
router.add_post('/login', login, name='login') router.add_post('/login', login, name='login')

View File

@ -11,9 +11,9 @@ from demo.dictionary_auth.handlers import configure_handlers
from demo.dictionary_auth.users import user_map from demo.dictionary_auth.users import user_map
def make_app(): def make_app() -> web.Application:
app = web.Application() app = web.Application()
app.user_map = user_map app['user_map'] = user_map
configure_handlers(app) configure_handlers(app)
# secret_key must be 32 url-safe base64-encoded bytes # secret_key must be 32 url-safe base64-encoded bytes

View File

@ -1,6 +1,11 @@
from collections import namedtuple from typing import NamedTuple, Tuple
class User(NamedTuple):
username: str
password: str
permissions: Tuple[str, ...]
User = namedtuple('User', ['username', 'password', 'permissions'])
user_map = { user_map = {
user.username: user for user in [ user.username: user for user in [

View File

@ -1,3 +1,6 @@
from enum import Enum
from typing import NoReturn, Optional, Union
from aiohttp import web from aiohttp import web
from aiohttp_session import SimpleCookieStorage, session_middleware from aiohttp_session import SimpleCookieStorage, session_middleware
from aiohttp_security import check_permission, \ from aiohttp_security import check_permission, \
@ -11,15 +14,15 @@ from aiohttp_security.abc import AbstractAuthorizationPolicy
# For more complicated authorization policies see examples # For more complicated authorization policies see examples
# in the 'demo' directory. # in the 'demo' directory.
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy): class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
async def authorized_userid(self, identity): async def authorized_userid(self, identity: str) -> Optional[str]:
"""Retrieve authorized user id. """Retrieve authorized user id.
Return the user_id of the user identified by the identity Return the user_id of the user identified by the identity
or 'None' if no user exists related to the identity. or 'None' if no user exists related to the identity.
""" """
if identity == 'jack': return identity if identity == 'jack' else None
return identity
async def permits(self, identity, permission, context=None): async def permits(self, identity: str, permission: Union[str, Enum],
context: None = None) -> bool:
"""Check user permissions. """Check user permissions.
Return True if the identity is allowed the permission Return True if the identity is allowed the permission
in the current context, else return False. in the current context, else return False.
@ -27,7 +30,7 @@ class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
return identity == 'jack' and permission in ('listen',) return identity == 'jack' and permission in ('listen',)
async def handler_root(request): async def handler_root(request: web.Request) -> web.Response:
is_logged = not await is_anonymous(request) is_logged = not await is_anonymous(request)
return web.Response(text='''<html><head></head><body> return web.Response(text='''<html><head></head><body>
Hello, I'm Jack, I'm {logged} logged in.<br /><br /> Hello, I'm Jack, I'm {logged} logged in.<br /><br />
@ -42,29 +45,29 @@ async def handler_root(request):
), content_type='text/html') ), content_type='text/html')
async def handler_login_jack(request): async def handler_login_jack(request: web.Request) -> NoReturn:
redirect_response = web.HTTPFound('/') redirect_response = web.HTTPFound('/')
await remember(request, redirect_response, 'jack') await remember(request, redirect_response, 'jack')
raise redirect_response raise redirect_response
async def handler_logout(request): async def handler_logout(request: web.Request) -> NoReturn:
redirect_response = web.HTTPFound('/') redirect_response = web.HTTPFound('/')
await forget(request, redirect_response) await forget(request, redirect_response)
raise redirect_response raise redirect_response
async def handler_listen(request): async def handler_listen(request: web.Request) -> web.Response:
await check_permission(request, 'listen') await check_permission(request, 'listen')
return web.Response(body="I can listen!") return web.Response(body="I can listen!")
async def handler_speak(request): async def handler_speak(request: web.Request) -> web.Response:
await check_permission(request, 'speak') await check_permission(request, 'speak')
return web.Response(body="I can speak!") return web.Response(body="I can speak!")
async def make_app(): async def make_app() -> web.Application:
# #
# WARNING!!! # WARNING!!!
# Never use SimpleCookieStorage on production!!! # Never use SimpleCookieStorage on production!!!

0
tests/__init__.py Normal file
View File

View File

@ -67,7 +67,7 @@ async def test_identify_broken_scheme(loop, make_token, aiohttp_client):
try: try:
await policy.identify(request) await policy.identify(request)
except ValueError as exc: except ValueError as exc:
raise web.HTTPBadRequest(reason=exc) raise web.HTTPBadRequest(reason=str(exc))
return web.Response() return web.Response()