From b3b536746060d46b358595814b64d3b528975731 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 18 Dec 2020 17:58:38 +0000 Subject: [PATCH] Add type annotations. --- .mypy.ini | 37 ++++++++++++++++ aiohttp_security/abc.py | 16 ++++--- aiohttp_security/api.py | 64 ++++++++++++++++------------ aiohttp_security/cookies_identity.py | 22 ++++++---- aiohttp_security/jwt_identity.py | 22 ++++++---- aiohttp_security/py.typed | 0 aiohttp_security/session_identity.py | 10 +++-- demo/__init__.py | 0 demo/database_auth/__init__.py | 0 demo/database_auth/db_auth.py | 24 +++++------ demo/database_auth/handlers.py | 25 ++++++----- demo/database_auth/main.py | 9 ++-- demo/dictionary_auth/__init__.py | 0 demo/dictionary_auth/authz.py | 17 +++++--- demo/dictionary_auth/handlers.py | 28 +++++++----- demo/dictionary_auth/main.py | 4 +- demo/dictionary_auth/users.py | 9 +++- demo/simple_example_auth.py | 23 +++++----- tests/__init__.py | 0 tests/test_jwt_identity.py | 2 +- 20 files changed, 200 insertions(+), 112 deletions(-) create mode 100644 .mypy.ini create mode 100644 aiohttp_security/py.typed create mode 100644 demo/__init__.py create mode 100644 demo/database_auth/__init__.py create mode 100644 demo/dictionary_auth/__init__.py create mode 100644 tests/__init__.py diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 0000000..700e458 --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,37 @@ +[mypy] +files = aiohttp_security, demo, tests +check_untyped_defs = True +follow_imports_for_stubs = True +disallow_any_decorated = True +disallow_any_generics = True +disallow_incomplete_defs = True +disallow_subclassing_any = True +disallow_untyped_calls = True +disallow_untyped_decorators = True +disallow_untyped_defs = True +implicit_reexport = False +no_implicit_optional = True +show_error_codes = True +strict_equality = True +warn_incomplete_stub = True +warn_redundant_casts = True +warn_unreachable = True +warn_unused_ignores = True +disallow_any_unimported = True +warn_return_any = True + +[mypy-aiohttp_security.abc.*] +disallow_any_decorated = False + +[mypy-tests.*] +disallow_any_decorated = False +disallow_untyped_defs = False + +[mypy-aiopg.*] +ignore_missing_imports = True + +[mypy-aioredis.*] +ignore_missing_imports = True + +[mypy-passlib.*] +ignore_missing_imports = True diff --git a/aiohttp_security/abc.py b/aiohttp_security/abc.py index 862abd8..0b63bb5 100644 --- a/aiohttp_security/abc.py +++ b/aiohttp_security/abc.py @@ -1,4 +1,8 @@ import abc +from enum import Enum +from typing import Any, Optional, Union + +from aiohttp import web # see http://plope.com/pyramid_auth_design_api_postmortem @@ -6,13 +10,14 @@ import abc class AbstractIdentityPolicy(metaclass=abc.ABCMeta): @abc.abstractmethod - async def identify(self, request): + async def identify(self, request: web.Request) -> Optional[str]: """Return the claimed identity of the user associated request or ``None`` if no identity can be found associated with the request.""" pass @abc.abstractmethod - async def remember(self, request, response, identity, **kwargs): + async def remember(self, request: web.Request, response: web.StreamResponse, + identity: str, **kwargs: Any) -> None: """Remember identity. Modify response object by filling it's headers with remembered user. @@ -23,7 +28,7 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def forget(self, request, response): + async def forget(self, request: web.Request, response: web.StreamResponse) -> None: """ Modify response which can be used to 'forget' the current identity on subsequent requests.""" pass @@ -32,7 +37,8 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta): class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta): @abc.abstractmethod - async def permits(self, identity, permission, context=None): + async def permits(self, identity: str, permission: Union[str, Enum], + context: Any = None) -> bool: """Check user permissions. Return True if the identity is allowed the permission in the @@ -41,7 +47,7 @@ class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def authorized_userid(self, identity): + async def authorized_userid(self, identity: str) -> Optional[str]: """Retrieve authorized user id. Return the user_id of the user identified by the identity diff --git a/aiohttp_security/api.py b/aiohttp_security/api.py index d6a2dd5..ccf1798 100644 --- a/aiohttp_security/api.py +++ b/aiohttp_security/api.py @@ -1,15 +1,23 @@ import enum import warnings -from aiohttp import web -from aiohttp_security.abc import (AbstractIdentityPolicy, - AbstractAuthorizationPolicy) from functools import wraps +from typing import Any, Callable, Optional, TypeVar, Union + +from aiohttp import web +from aiohttp_security.abc import AbstractAuthorizationPolicy, AbstractIdentityPolicy IDENTITY_KEY = 'aiohttp_security_identity_policy' AUTZ_KEY = 'aiohttp_security_autz_policy' +# _AIP/_AAP are shorthand for Optional[policy] when we retrieve from request. +_AAP = Optional[AbstractAuthorizationPolicy] +_AIP = Optional[AbstractIdentityPolicy] +_Handler = TypeVar('_Handler', bound=Union[Callable[[web.Request], Any], + Callable[[object, web.Request], Any]]) -async def remember(request, response, identity, **kwargs): + +async def remember(request: web.Request, response: web.StreamResponse, + identity: str, **kwargs: Any) -> None: """Remember identity into response. The action is performed by identity_policy.remember() @@ -30,7 +38,7 @@ async def remember(request, response, identity, **kwargs): await identity_policy.remember(request, response, identity, **kwargs) -async def forget(request, response): +async def forget(request: web.Request, response: web.StreamResponse) -> None: """Forget previously remembered identity. Usually it clears cookie or server-side storage to forget user @@ -47,9 +55,9 @@ async def forget(request, response): await identity_policy.forget(request, response) -async def authorized_userid(request): - identity_policy = request.config_dict.get(IDENTITY_KEY) - autz_policy = request.config_dict.get(AUTZ_KEY) +async def authorized_userid(request: web.Request) -> Optional[str]: + identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY) + autz_policy: _AAP = request.config_dict.get(AUTZ_KEY) if identity_policy is None or autz_policy is None: return None identity = await identity_policy.identify(request) @@ -59,20 +67,21 @@ async def authorized_userid(request): return user_id -async def permits(request, permission, context=None): +async def permits(request: web.Request, permission: Union[str, enum.Enum], + context: Any = None) -> bool: assert isinstance(permission, (str, enum.Enum)), permission assert permission - identity_policy = request.config_dict.get(IDENTITY_KEY) - autz_policy = request.config_dict.get(AUTZ_KEY) + identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY) + autz_policy: _AAP = request.config_dict.get(AUTZ_KEY) if identity_policy is None or autz_policy is None: return True identity = await identity_policy.identify(request) - # non-registered user still may has some permissions + # non-registered user still may have some permissions access = await autz_policy.permits(identity, permission, context) return access -async def is_anonymous(request): +async def is_anonymous(request: web.Request) -> bool: """Check if user is anonymous. User is considered anonymous if there is not identity @@ -87,7 +96,7 @@ async def is_anonymous(request): return False -async def check_authorized(request): +async def check_authorized(request: web.Request) -> str: """Checker that raises HTTPUnauthorized for anonymous users. """ userid = await authorized_userid(request) @@ -96,31 +105,32 @@ async def check_authorized(request): return userid -def login_required(fn): +def login_required(fn: _Handler) -> _Handler: """Decorator that restrict access only for authorized users. User is considered authorized if authorized_userid returns some value. """ @wraps(fn) - async def wrapped(*args, **kwargs): + async def wrapped(*args: Union[object, web.Request]) -> Any: request = args[-1] - if not isinstance(request, web.BaseRequest): + if not isinstance(request, web.Request): msg = ("Incorrect decorator usage. " "Expecting `def handler(request)` " "or `def handler(self, request)`.") raise RuntimeError(msg) await check_authorized(request) - return await fn(*args, **kwargs) + return await fn(*args) # type: ignore[arg-type] warnings.warn("login_required decorator is deprecated, " "use check_authorized instead", DeprecationWarning) - return wrapped + return wrapped # type: ignore[return-value] -async def check_permission(request, permission, context=None): +async def check_permission(request: web.Request, permission: Union[str, enum.Enum], + context: Any = None) -> None: """Checker that passes only to authoraised users with given permission. If user is not authorized - raises HTTPUnauthorized, @@ -134,10 +144,7 @@ async def check_permission(request, permission, context=None): raise web.HTTPForbidden() -def has_permission( - permission, - context=None, -): +def has_permission(permission: Union[str, enum.Enum], context: Any = None): # type: ignore """Decorator that restricts access only for authorized users with correct permissions. @@ -145,11 +152,11 @@ def has_permission( if user is authorized and does not have permission - raises HTTPForbidden. """ - def wrapper(fn): + def wrapper(fn): # type: ignore @wraps(fn) - async def wrapped(*args, **kwargs): + async def wrapped(*args, **kwargs): # type: ignore request = args[-1] - if not isinstance(request, web.BaseRequest): + if not isinstance(request, web.Request): msg = ("Incorrect decorator usage. " "Expecting `def handler(request)` " "or `def handler(self, request)`.") @@ -166,7 +173,8 @@ def has_permission( return wrapper -def setup(app, identity_policy, autz_policy): +def setup(app: web.Application, identity_policy: AbstractIdentityPolicy, + autz_policy: AbstractAuthorizationPolicy) -> None: assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy diff --git a/aiohttp_security/cookies_identity.py b/aiohttp_security/cookies_identity.py index 3822ef9..f20c422 100644 --- a/aiohttp_security/cookies_identity.py +++ b/aiohttp_security/cookies_identity.py @@ -5,28 +5,32 @@ more handy. """ +from aiohttp import web +from typing import Any, NewType, Optional, Union, cast + from .abc import AbstractIdentityPolicy - -sentinel = object() +_Sentinel = NewType('_Sentinel', object) +sentinel = _Sentinel(object()) class CookiesIdentityPolicy(AbstractIdentityPolicy): - def __init__(self): + def __init__(self) -> None: self._cookie_name = 'AIOHTTP_SECURITY' self._max_age = 30 * 24 * 3600 - async def identify(self, request): - identity = request.cookies.get(self._cookie_name) - return identity + async def identify(self, request: web.Request) -> Optional[str]: + return request.cookies.get(self._cookie_name) - async def remember(self, request, response, identity, max_age=sentinel, - **kwargs): + async def remember(self, request: web.Request, response: web.StreamResponse, + identity: str, max_age: Union[_Sentinel, Optional[int]] = sentinel, + **kwargs: Any) -> None: if max_age is sentinel: max_age = self._max_age + max_age = cast(Optional[int], max_age) response.set_cookie(self._cookie_name, identity, max_age=max_age, **kwargs) - async def forget(self, request, response): + async def forget(self, request: web.Request, response: web.StreamResponse) -> None: response.del_cookie(self._cookie_name) diff --git a/aiohttp_security/jwt_identity.py b/aiohttp_security/jwt_identity.py index bd6ffe7..e456007 100644 --- a/aiohttp_security/jwt_identity.py +++ b/aiohttp_security/jwt_identity.py @@ -2,12 +2,17 @@ """ +from typing import Optional + +from aiohttp import web + from .abc import AbstractIdentityPolicy try: import jwt + HAS_JWT = True except ImportError: # pragma: no cover - jwt = None + HAS_JWT = False AUTH_HEADER_NAME = 'Authorization' @@ -15,21 +20,21 @@ AUTH_SCHEME = 'Bearer ' class JWTIdentityPolicy(AbstractIdentityPolicy): - def __init__(self, secret, algorithm='HS256'): - if jwt is None: + def __init__(self, secret: str, algorithm: str = 'HS256'): + if not HAS_JWT: raise RuntimeError('Please install `PyJWT`') self.secret = secret self.algorithm = algorithm - async def identify(self, request): + async def identify(self, request: web.Request) -> Optional[str]: header_identity = request.headers.get(AUTH_HEADER_NAME) if header_identity is None: - return + return None if not header_identity.startswith(AUTH_SCHEME): raise ValueError('Invalid authorization scheme. ' + - 'Should be `Bearer `') + 'Should be `{}`'.format(AUTH_SCHEME)) token = header_identity.split(' ')[1].strip() @@ -38,8 +43,9 @@ class JWTIdentityPolicy(AbstractIdentityPolicy): algorithms=[self.algorithm]) return identity - async def remember(self, *args, **kwargs): # pragma: no cover + async def remember(self, request: web.Request, response: web.StreamResponse, + identity: str, **kwargs: None) -> None: pass - async def forget(self, request, response): # pragma: no cover + async def forget(self, request: web.Request, response: web.StreamResponse) -> None: pass diff --git a/aiohttp_security/py.typed b/aiohttp_security/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/aiohttp_security/session_identity.py b/aiohttp_security/session_identity.py index 13d197a..33c5ef2 100644 --- a/aiohttp_security/session_identity.py +++ b/aiohttp_security/session_identity.py @@ -4,6 +4,7 @@ aiohttp_session.setup() should be called on application initialization to configure aiohttp_session properly. """ +from aiohttp import web try: from aiohttp_session import get_session HAS_AIOHTTP_SESSION = True @@ -15,21 +16,22 @@ from .abc import AbstractIdentityPolicy class SessionIdentityPolicy(AbstractIdentityPolicy): - def __init__(self, session_key='AIOHTTP_SECURITY'): + def __init__(self, session_key: str = 'AIOHTTP_SECURITY'): self._session_key = session_key if not HAS_AIOHTTP_SESSION: # pragma: no cover raise ImportError( 'SessionIdentityPolicy requires `aiohttp_session`') - async def identify(self, request): + async def identify(self, request: web.Request) -> str: session = await get_session(request) return session.get(self._session_key) - async def remember(self, request, response, identity, **kwargs): + async def remember(self, request: web.Request, response: web.StreamResponse, + identity: str, **kwargs: None) -> None: session = await get_session(request) session[self._session_key] = identity - async def forget(self, request, response): + async def forget(self, request: web.Request, response: web.StreamResponse) -> None: session = await get_session(request) session.pop(self._session_key, None) diff --git a/demo/__init__.py b/demo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/database_auth/__init__.py b/demo/database_auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/database_auth/db_auth.py b/demo/database_auth/db_auth.py index 72b1ad2..5eaf04a 100644 --- a/demo/database_auth/db_auth.py +++ b/demo/database_auth/db_auth.py @@ -1,5 +1,7 @@ -import sqlalchemy as sa +from enum import Enum +from typing import Any, Optional, Union +import sqlalchemy as sa from aiohttp_security.abc import AbstractAuthorizationPolicy from passlib.hash import sha256_crypt @@ -7,13 +9,13 @@ from . import db class DBAuthorizationPolicy(AbstractAuthorizationPolicy): - def __init__(self, dbengine): + def __init__(self, dbengine: Any): self.dbengine = dbengine - async def authorized_userid(self, identity): + async def authorized_userid(self, identity: str) -> Optional[str]: async with self.dbengine.acquire() as conn: where = sa.and_(db.users.c.login == identity, - sa.not_(db.users.c.disabled)) + sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call] query = db.users.count().where(where) ret = await conn.scalar(query) if ret: @@ -21,13 +23,11 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy): else: return None - async def permits(self, identity, permission, context=None): - if identity is None: - return False - + async def permits(self, identity: str, permission: Union[str, Enum], + context: None = None) -> bool: async with self.dbengine.acquire() as conn: where = sa.and_(db.users.c.login == identity, - sa.not_(db.users.c.disabled)) + sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call] query = db.users.select().where(where) ret = await conn.execute(query) user = await ret.fetchone() @@ -49,14 +49,14 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy): return False -async def check_credentials(db_engine, username, password): +async def check_credentials(db_engine: Any, username: str, password: str) -> bool: async with db_engine.acquire() as conn: where = sa.and_(db.users.c.login == username, - sa.not_(db.users.c.disabled)) + sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call] query = db.users.select().where(where) ret = await conn.execute(query) user = await ret.fetchone() if user is not None: hashed = user[2] - return sha256_crypt.verify(password, hashed) + return sha256_crypt.verify(password, hashed) # type: ignore[no-any-return] return False diff --git a/demo/database_auth/handlers.py b/demo/database_auth/handlers.py index 0cc9484..19a2bf9 100644 --- a/demo/database_auth/handlers.py +++ b/demo/database_auth/handlers.py @@ -1,4 +1,5 @@ from textwrap import dedent +from typing import NoReturn from aiohttp import web @@ -27,7 +28,7 @@ class Web(object): """) - async def index(self, request): + async def index(self, request: web.Request) -> web.Response: username = await authorized_userid(request) if username: template = self.index_template.format( @@ -37,37 +38,41 @@ class Web(object): response = web.Response(body=template.encode()) return response - async def login(self, request): - response = web.HTTPFound('/') + async def login(self, request: web.Request) -> NoReturn: + invalid_resp = web.HTTPUnauthorized(body=b'Invalid username/password combination') form = await request.post() login = form.get('login') password = form.get('password') - db_engine = request.app.db_engine + db_engine = request.app['db_engine'] + + if not (isinstance(login, str) and isinstance(password, str)): + raise invalid_resp + if await check_credentials(db_engine, login, password): + response = web.HTTPFound('/') await remember(request, response, login) raise response - raise web.HTTPUnauthorized( - body=b'Invalid username/password combination') + raise invalid_resp - async def logout(self, request): + async def logout(self, request: web.Request) -> web.Response: await check_authorized(request) response = web.Response(body=b'You have been logged out') await forget(request, response) return response - async def internal_page(self, request): + async def internal_page(self, request: web.Request) -> web.Response: await check_permission(request, 'public') response = web.Response( body=b'This page is visible for all registered users') return response - async def protected_page(self, request): + async def protected_page(self, request: web.Request) -> web.Response: await check_permission(request, 'protected') response = web.Response(body=b'You are on protected page') return response - def configure(self, app): + def configure(self, app: web.Application) -> None: router = app.router router.add_route('GET', '/', self.index, name='index') router.add_route('POST', '/login', self.login, name='login') diff --git a/demo/database_auth/main.py b/demo/database_auth/main.py index 822df44..b6b81d8 100644 --- a/demo/database_auth/main.py +++ b/demo/database_auth/main.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any, Tuple from aiohttp import web from aiohttp_session import setup as setup_session @@ -13,14 +14,14 @@ from demo.database_auth.db_auth import DBAuthorizationPolicy from demo.database_auth.handlers import Web -async def init(loop): +async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]: redis_pool = await create_pool(('localhost', 6379)) db_engine = await create_engine(user='aiohttp_security', password='aiohttp_security', database='aiohttp_security', host='127.0.0.1') app = web.Application() - app.db_engine = db_engine + app['db_engine'] = db_engine setup_session(app, RedisStorage(redis_pool)) setup_security(app, SessionIdentityPolicy(), @@ -35,7 +36,7 @@ async def init(loop): return srv, app, handler -async def finalize(srv, app, handler): +async def finalize(srv: Any, app: Any, handler: Any) -> None: sock = srv.sockets[0] app.loop.remove_reader(sock.fileno()) sock.close() @@ -46,7 +47,7 @@ async def finalize(srv, app, handler): await app.finish() -def main(): +def main() -> None: loop = asyncio.get_event_loop() srv, app, handler = loop.run_until_complete(init(loop)) try: diff --git a/demo/dictionary_auth/__init__.py b/demo/dictionary_auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/dictionary_auth/authz.py b/demo/dictionary_auth/authz.py index 0f9baae..57ef52a 100644 --- a/demo/dictionary_auth/authz.py +++ b/demo/dictionary_auth/authz.py @@ -1,20 +1,25 @@ +from enum import Enum +from typing import Dict, Optional, Union + from aiohttp_security.abc import AbstractAuthorizationPolicy +from .users import User + class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy): - def __init__(self, user_map): + def __init__(self, user_map: Dict[str, User]): super().__init__() self.user_map = user_map - async def authorized_userid(self, identity): + async def authorized_userid(self, identity: str) -> Optional[str]: """Retrieve authorized user id. Return the user_id of the user identified by the identity or 'None' if no user exists related to the identity. """ - if identity in self.user_map: - return identity + return identity if identity in self.user_map else None - async def permits(self, identity, permission, context=None): + async def permits(self, identity: str, permission: Union[str, Enum], + context: None = None) -> bool: """Check user permissions. Return True if the identity is allowed the permission in the current context, else return False. @@ -26,7 +31,7 @@ class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy): return permission in user.permissions -async def check_credentials(user_map, username, password): +async def check_credentials(user_map: Dict[str, User], username: str, password: str) -> bool: user = user_map.get(username) if not user: return False diff --git a/demo/dictionary_auth/handlers.py b/demo/dictionary_auth/handlers.py index 6c19bab..e558715 100644 --- a/demo/dictionary_auth/handlers.py +++ b/demo/dictionary_auth/handlers.py @@ -1,4 +1,5 @@ from textwrap import dedent +from typing import Dict, NoReturn from aiohttp import web @@ -8,6 +9,7 @@ from aiohttp_security import ( ) from .authz import check_credentials +from .users import User index_template = dedent(""" @@ -27,7 +29,7 @@ index_template = dedent(""" """) -async def index(request): +async def index(request: web.Request) -> web.Response: username = await authorized_userid(request) if username: template = index_template.format( @@ -40,22 +42,26 @@ async def index(request): ) -async def login(request): - response = web.HTTPFound('/') +async def login(request: web.Request) -> NoReturn: + user_map: Dict[str, User] = request.app['user_map'] + invalid_response = web.HTTPUnauthorized(body='Invalid username / password combination') form = await request.post() username = form.get('username') password = form.get('password') - verified = await check_credentials( - request.app.user_map, username, password) + if not (isinstance(username, str) and isinstance(password, str)): + raise invalid_response + + verified = await check_credentials(user_map, username, password) if verified: + response = web.HTTPFound('/') await remember(request, response, username) - return response + raise response - return web.HTTPUnauthorized(body='Invalid username / password combination') + raise invalid_response -async def logout(request): +async def logout(request: web.Request) -> web.Response: await check_authorized(request) response = web.Response( text='You have been logged out', @@ -65,7 +71,7 @@ async def logout(request): return response -async def internal_page(request): +async def internal_page(request: web.Request) -> web.Response: await check_permission(request, 'public') response = web.Response( text='This page is visible for all registered users', @@ -74,7 +80,7 @@ async def internal_page(request): return response -async def protected_page(request): +async def protected_page(request: web.Request) -> web.Response: await check_permission(request, 'protected') response = web.Response( text='You are on protected page', @@ -83,7 +89,7 @@ async def protected_page(request): return response -def configure_handlers(app): +def configure_handlers(app: web.Application) -> None: router = app.router router.add_get('/', index, name='index') router.add_post('/login', login, name='login') diff --git a/demo/dictionary_auth/main.py b/demo/dictionary_auth/main.py index b4fe2b4..0bd5f1a 100644 --- a/demo/dictionary_auth/main.py +++ b/demo/dictionary_auth/main.py @@ -11,9 +11,9 @@ from demo.dictionary_auth.handlers import configure_handlers from demo.dictionary_auth.users import user_map -def make_app(): +def make_app() -> web.Application: app = web.Application() - app.user_map = user_map + app['user_map'] = user_map configure_handlers(app) # secret_key must be 32 url-safe base64-encoded bytes diff --git a/demo/dictionary_auth/users.py b/demo/dictionary_auth/users.py index 967b2bb..65a9aa5 100644 --- a/demo/dictionary_auth/users.py +++ b/demo/dictionary_auth/users.py @@ -1,6 +1,11 @@ -from collections import namedtuple +from typing import NamedTuple, Tuple + + +class User(NamedTuple): + username: str + password: str + permissions: Tuple[str, ...] -User = namedtuple('User', ['username', 'password', 'permissions']) user_map = { user.username: user for user in [ diff --git a/demo/simple_example_auth.py b/demo/simple_example_auth.py index 5b4b82d..a613518 100644 --- a/demo/simple_example_auth.py +++ b/demo/simple_example_auth.py @@ -1,3 +1,6 @@ +from enum import Enum +from typing import NoReturn, Optional, Union + from aiohttp import web from aiohttp_session import SimpleCookieStorage, session_middleware from aiohttp_security import check_permission, \ @@ -11,15 +14,15 @@ from aiohttp_security.abc import AbstractAuthorizationPolicy # For more complicated authorization policies see examples # in the 'demo' directory. class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy): - async def authorized_userid(self, identity): + async def authorized_userid(self, identity: str) -> Optional[str]: """Retrieve authorized user id. Return the user_id of the user identified by the identity or 'None' if no user exists related to the identity. """ - if identity == 'jack': - return identity + return identity if identity == 'jack' else None - async def permits(self, identity, permission, context=None): + async def permits(self, identity: str, permission: Union[str, Enum], + context: None = None) -> bool: """Check user permissions. Return True if the identity is allowed the permission in the current context, else return False. @@ -27,7 +30,7 @@ class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy): return identity == 'jack' and permission in ('listen',) -async def handler_root(request): +async def handler_root(request: web.Request) -> web.Response: is_logged = not await is_anonymous(request) return web.Response(text=''' Hello, I'm Jack, I'm {logged} logged in.

@@ -42,29 +45,29 @@ async def handler_root(request): ), content_type='text/html') -async def handler_login_jack(request): +async def handler_login_jack(request: web.Request) -> NoReturn: redirect_response = web.HTTPFound('/') await remember(request, redirect_response, 'jack') raise redirect_response -async def handler_logout(request): +async def handler_logout(request: web.Request) -> NoReturn: redirect_response = web.HTTPFound('/') await forget(request, redirect_response) raise redirect_response -async def handler_listen(request): +async def handler_listen(request: web.Request) -> web.Response: await check_permission(request, 'listen') return web.Response(body="I can listen!") -async def handler_speak(request): +async def handler_speak(request: web.Request) -> web.Response: await check_permission(request, 'speak') return web.Response(body="I can speak!") -async def make_app(): +async def make_app() -> web.Application: # # WARNING!!! # Never use SimpleCookieStorage on production!!! diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_jwt_identity.py b/tests/test_jwt_identity.py index 813deac..a01e815 100644 --- a/tests/test_jwt_identity.py +++ b/tests/test_jwt_identity.py @@ -67,7 +67,7 @@ async def test_identify_broken_scheme(loop, make_token, aiohttp_client): try: await policy.identify(request) except ValueError as exc: - raise web.HTTPBadRequest(reason=exc) + raise web.HTTPBadRequest(reason=str(exc)) return web.Response()