diff --git a/.mypy.ini b/.mypy.ini
new file mode 100644
index 0000000..700e458
--- /dev/null
+++ b/.mypy.ini
@@ -0,0 +1,37 @@
+[mypy]
+files = aiohttp_security, demo, tests
+check_untyped_defs = True
+follow_imports_for_stubs = True
+disallow_any_decorated = True
+disallow_any_generics = True
+disallow_incomplete_defs = True
+disallow_subclassing_any = True
+disallow_untyped_calls = True
+disallow_untyped_decorators = True
+disallow_untyped_defs = True
+implicit_reexport = False
+no_implicit_optional = True
+show_error_codes = True
+strict_equality = True
+warn_incomplete_stub = True
+warn_redundant_casts = True
+warn_unreachable = True
+warn_unused_ignores = True
+disallow_any_unimported = True
+warn_return_any = True
+
+[mypy-aiohttp_security.abc.*]
+disallow_any_decorated = False
+
+[mypy-tests.*]
+disallow_any_decorated = False
+disallow_untyped_defs = False
+
+[mypy-aiopg.*]
+ignore_missing_imports = True
+
+[mypy-aioredis.*]
+ignore_missing_imports = True
+
+[mypy-passlib.*]
+ignore_missing_imports = True
diff --git a/aiohttp_security/abc.py b/aiohttp_security/abc.py
index 862abd8..0b63bb5 100644
--- a/aiohttp_security/abc.py
+++ b/aiohttp_security/abc.py
@@ -1,4 +1,8 @@
import abc
+from enum import Enum
+from typing import Any, Optional, Union
+
+from aiohttp import web
# see http://plope.com/pyramid_auth_design_api_postmortem
@@ -6,13 +10,14 @@ import abc
class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
@abc.abstractmethod
- async def identify(self, request):
+ async def identify(self, request: web.Request) -> Optional[str]:
"""Return the claimed identity of the user associated request or
``None`` if no identity can be found associated with the request."""
pass
@abc.abstractmethod
- async def remember(self, request, response, identity, **kwargs):
+ async def remember(self, request: web.Request, response: web.StreamResponse,
+ identity: str, **kwargs: Any) -> None:
"""Remember identity.
Modify response object by filling it's headers with remembered user.
@@ -23,7 +28,7 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
- async def forget(self, request, response):
+ async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
""" Modify response which can be used to 'forget' the
current identity on subsequent requests."""
pass
@@ -32,7 +37,8 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
@abc.abstractmethod
- async def permits(self, identity, permission, context=None):
+ async def permits(self, identity: str, permission: Union[str, Enum],
+ context: Any = None) -> bool:
"""Check user permissions.
Return True if the identity is allowed the permission in the
@@ -41,7 +47,7 @@ class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
- async def authorized_userid(self, identity):
+ async def authorized_userid(self, identity: str) -> Optional[str]:
"""Retrieve authorized user id.
Return the user_id of the user identified by the identity
diff --git a/aiohttp_security/api.py b/aiohttp_security/api.py
index d6a2dd5..ccf1798 100644
--- a/aiohttp_security/api.py
+++ b/aiohttp_security/api.py
@@ -1,15 +1,23 @@
import enum
import warnings
-from aiohttp import web
-from aiohttp_security.abc import (AbstractIdentityPolicy,
- AbstractAuthorizationPolicy)
from functools import wraps
+from typing import Any, Callable, Optional, TypeVar, Union
+
+from aiohttp import web
+from aiohttp_security.abc import AbstractAuthorizationPolicy, AbstractIdentityPolicy
IDENTITY_KEY = 'aiohttp_security_identity_policy'
AUTZ_KEY = 'aiohttp_security_autz_policy'
+# _AIP/_AAP are shorthand for Optional[policy] when we retrieve from request.
+_AAP = Optional[AbstractAuthorizationPolicy]
+_AIP = Optional[AbstractIdentityPolicy]
+_Handler = TypeVar('_Handler', bound=Union[Callable[[web.Request], Any],
+ Callable[[object, web.Request], Any]])
-async def remember(request, response, identity, **kwargs):
+
+async def remember(request: web.Request, response: web.StreamResponse,
+ identity: str, **kwargs: Any) -> None:
"""Remember identity into response.
The action is performed by identity_policy.remember()
@@ -30,7 +38,7 @@ async def remember(request, response, identity, **kwargs):
await identity_policy.remember(request, response, identity, **kwargs)
-async def forget(request, response):
+async def forget(request: web.Request, response: web.StreamResponse) -> None:
"""Forget previously remembered identity.
Usually it clears cookie or server-side storage to forget user
@@ -47,9 +55,9 @@ async def forget(request, response):
await identity_policy.forget(request, response)
-async def authorized_userid(request):
- identity_policy = request.config_dict.get(IDENTITY_KEY)
- autz_policy = request.config_dict.get(AUTZ_KEY)
+async def authorized_userid(request: web.Request) -> Optional[str]:
+ identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY)
+ autz_policy: _AAP = request.config_dict.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None:
return None
identity = await identity_policy.identify(request)
@@ -59,20 +67,21 @@ async def authorized_userid(request):
return user_id
-async def permits(request, permission, context=None):
+async def permits(request: web.Request, permission: Union[str, enum.Enum],
+ context: Any = None) -> bool:
assert isinstance(permission, (str, enum.Enum)), permission
assert permission
- identity_policy = request.config_dict.get(IDENTITY_KEY)
- autz_policy = request.config_dict.get(AUTZ_KEY)
+ identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY)
+ autz_policy: _AAP = request.config_dict.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None:
return True
identity = await identity_policy.identify(request)
- # non-registered user still may has some permissions
+ # non-registered user still may have some permissions
access = await autz_policy.permits(identity, permission, context)
return access
-async def is_anonymous(request):
+async def is_anonymous(request: web.Request) -> bool:
"""Check if user is anonymous.
User is considered anonymous if there is not identity
@@ -87,7 +96,7 @@ async def is_anonymous(request):
return False
-async def check_authorized(request):
+async def check_authorized(request: web.Request) -> str:
"""Checker that raises HTTPUnauthorized for anonymous users.
"""
userid = await authorized_userid(request)
@@ -96,31 +105,32 @@ async def check_authorized(request):
return userid
-def login_required(fn):
+def login_required(fn: _Handler) -> _Handler:
"""Decorator that restrict access only for authorized users.
User is considered authorized if authorized_userid
returns some value.
"""
@wraps(fn)
- async def wrapped(*args, **kwargs):
+ async def wrapped(*args: Union[object, web.Request]) -> Any:
request = args[-1]
- if not isinstance(request, web.BaseRequest):
+ if not isinstance(request, web.Request):
msg = ("Incorrect decorator usage. "
"Expecting `def handler(request)` "
"or `def handler(self, request)`.")
raise RuntimeError(msg)
await check_authorized(request)
- return await fn(*args, **kwargs)
+ return await fn(*args) # type: ignore[arg-type]
warnings.warn("login_required decorator is deprecated, "
"use check_authorized instead",
DeprecationWarning)
- return wrapped
+ return wrapped # type: ignore[return-value]
-async def check_permission(request, permission, context=None):
+async def check_permission(request: web.Request, permission: Union[str, enum.Enum],
+ context: Any = None) -> None:
"""Checker that passes only to authoraised users with given permission.
If user is not authorized - raises HTTPUnauthorized,
@@ -134,10 +144,7 @@ async def check_permission(request, permission, context=None):
raise web.HTTPForbidden()
-def has_permission(
- permission,
- context=None,
-):
+def has_permission(permission: Union[str, enum.Enum], context: Any = None): # type: ignore
"""Decorator that restricts access only for authorized users
with correct permissions.
@@ -145,11 +152,11 @@ def has_permission(
if user is authorized and does not have permission -
raises HTTPForbidden.
"""
- def wrapper(fn):
+ def wrapper(fn): # type: ignore
@wraps(fn)
- async def wrapped(*args, **kwargs):
+ async def wrapped(*args, **kwargs): # type: ignore
request = args[-1]
- if not isinstance(request, web.BaseRequest):
+ if not isinstance(request, web.Request):
msg = ("Incorrect decorator usage. "
"Expecting `def handler(request)` "
"or `def handler(self, request)`.")
@@ -166,7 +173,8 @@ def has_permission(
return wrapper
-def setup(app, identity_policy, autz_policy):
+def setup(app: web.Application, identity_policy: AbstractIdentityPolicy,
+ autz_policy: AbstractAuthorizationPolicy) -> None:
assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy
assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy
diff --git a/aiohttp_security/cookies_identity.py b/aiohttp_security/cookies_identity.py
index 3822ef9..f20c422 100644
--- a/aiohttp_security/cookies_identity.py
+++ b/aiohttp_security/cookies_identity.py
@@ -5,28 +5,32 @@ more handy.
"""
+from aiohttp import web
+from typing import Any, NewType, Optional, Union, cast
+
from .abc import AbstractIdentityPolicy
-
-sentinel = object()
+_Sentinel = NewType('_Sentinel', object)
+sentinel = _Sentinel(object())
class CookiesIdentityPolicy(AbstractIdentityPolicy):
- def __init__(self):
+ def __init__(self) -> None:
self._cookie_name = 'AIOHTTP_SECURITY'
self._max_age = 30 * 24 * 3600
- async def identify(self, request):
- identity = request.cookies.get(self._cookie_name)
- return identity
+ async def identify(self, request: web.Request) -> Optional[str]:
+ return request.cookies.get(self._cookie_name)
- async def remember(self, request, response, identity, max_age=sentinel,
- **kwargs):
+ async def remember(self, request: web.Request, response: web.StreamResponse,
+ identity: str, max_age: Union[_Sentinel, Optional[int]] = sentinel,
+ **kwargs: Any) -> None:
if max_age is sentinel:
max_age = self._max_age
+ max_age = cast(Optional[int], max_age)
response.set_cookie(self._cookie_name, identity,
max_age=max_age, **kwargs)
- async def forget(self, request, response):
+ async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
response.del_cookie(self._cookie_name)
diff --git a/aiohttp_security/jwt_identity.py b/aiohttp_security/jwt_identity.py
index bd6ffe7..e456007 100644
--- a/aiohttp_security/jwt_identity.py
+++ b/aiohttp_security/jwt_identity.py
@@ -2,12 +2,17 @@
"""
+from typing import Optional
+
+from aiohttp import web
+
from .abc import AbstractIdentityPolicy
try:
import jwt
+ HAS_JWT = True
except ImportError: # pragma: no cover
- jwt = None
+ HAS_JWT = False
AUTH_HEADER_NAME = 'Authorization'
@@ -15,21 +20,21 @@ AUTH_SCHEME = 'Bearer '
class JWTIdentityPolicy(AbstractIdentityPolicy):
- def __init__(self, secret, algorithm='HS256'):
- if jwt is None:
+ def __init__(self, secret: str, algorithm: str = 'HS256'):
+ if not HAS_JWT:
raise RuntimeError('Please install `PyJWT`')
self.secret = secret
self.algorithm = algorithm
- async def identify(self, request):
+ async def identify(self, request: web.Request) -> Optional[str]:
header_identity = request.headers.get(AUTH_HEADER_NAME)
if header_identity is None:
- return
+ return None
if not header_identity.startswith(AUTH_SCHEME):
raise ValueError('Invalid authorization scheme. ' +
- 'Should be `Bearer
""") - async def index(self, request): + async def index(self, request: web.Request) -> web.Response: username = await authorized_userid(request) if username: template = self.index_template.format( @@ -37,37 +38,41 @@ class Web(object): response = web.Response(body=template.encode()) return response - async def login(self, request): - response = web.HTTPFound('/') + async def login(self, request: web.Request) -> NoReturn: + invalid_resp = web.HTTPUnauthorized(body=b'Invalid username/password combination') form = await request.post() login = form.get('login') password = form.get('password') - db_engine = request.app.db_engine + db_engine = request.app['db_engine'] + + if not (isinstance(login, str) and isinstance(password, str)): + raise invalid_resp + if await check_credentials(db_engine, login, password): + response = web.HTTPFound('/') await remember(request, response, login) raise response - raise web.HTTPUnauthorized( - body=b'Invalid username/password combination') + raise invalid_resp - async def logout(self, request): + async def logout(self, request: web.Request) -> web.Response: await check_authorized(request) response = web.Response(body=b'You have been logged out') await forget(request, response) return response - async def internal_page(self, request): + async def internal_page(self, request: web.Request) -> web.Response: await check_permission(request, 'public') response = web.Response( body=b'This page is visible for all registered users') return response - async def protected_page(self, request): + async def protected_page(self, request: web.Request) -> web.Response: await check_permission(request, 'protected') response = web.Response(body=b'You are on protected page') return response - def configure(self, app): + def configure(self, app: web.Application) -> None: router = app.router router.add_route('GET', '/', self.index, name='index') router.add_route('POST', '/login', self.login, name='login') diff --git a/demo/database_auth/main.py b/demo/database_auth/main.py index 822df44..b6b81d8 100644 --- a/demo/database_auth/main.py +++ b/demo/database_auth/main.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any, Tuple from aiohttp import web from aiohttp_session import setup as setup_session @@ -13,14 +14,14 @@ from demo.database_auth.db_auth import DBAuthorizationPolicy from demo.database_auth.handlers import Web -async def init(loop): +async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]: redis_pool = await create_pool(('localhost', 6379)) db_engine = await create_engine(user='aiohttp_security', password='aiohttp_security', database='aiohttp_security', host='127.0.0.1') app = web.Application() - app.db_engine = db_engine + app['db_engine'] = db_engine setup_session(app, RedisStorage(redis_pool)) setup_security(app, SessionIdentityPolicy(), @@ -35,7 +36,7 @@ async def init(loop): return srv, app, handler -async def finalize(srv, app, handler): +async def finalize(srv: Any, app: Any, handler: Any) -> None: sock = srv.sockets[0] app.loop.remove_reader(sock.fileno()) sock.close() @@ -46,7 +47,7 @@ async def finalize(srv, app, handler): await app.finish() -def main(): +def main() -> None: loop = asyncio.get_event_loop() srv, app, handler = loop.run_until_complete(init(loop)) try: diff --git a/demo/dictionary_auth/__init__.py b/demo/dictionary_auth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/demo/dictionary_auth/authz.py b/demo/dictionary_auth/authz.py index 0f9baae..57ef52a 100644 --- a/demo/dictionary_auth/authz.py +++ b/demo/dictionary_auth/authz.py @@ -1,20 +1,25 @@ +from enum import Enum +from typing import Dict, Optional, Union + from aiohttp_security.abc import AbstractAuthorizationPolicy +from .users import User + class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy): - def __init__(self, user_map): + def __init__(self, user_map: Dict[str, User]): super().__init__() self.user_map = user_map - async def authorized_userid(self, identity): + async def authorized_userid(self, identity: str) -> Optional[str]: """Retrieve authorized user id. Return the user_id of the user identified by the identity or 'None' if no user exists related to the identity. """ - if identity in self.user_map: - return identity + return identity if identity in self.user_map else None - async def permits(self, identity, permission, context=None): + async def permits(self, identity: str, permission: Union[str, Enum], + context: None = None) -> bool: """Check user permissions. Return True if the identity is allowed the permission in the current context, else return False. @@ -26,7 +31,7 @@ class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy): return permission in user.permissions -async def check_credentials(user_map, username, password): +async def check_credentials(user_map: Dict[str, User], username: str, password: str) -> bool: user = user_map.get(username) if not user: return False diff --git a/demo/dictionary_auth/handlers.py b/demo/dictionary_auth/handlers.py index 6c19bab..e558715 100644 --- a/demo/dictionary_auth/handlers.py +++ b/demo/dictionary_auth/handlers.py @@ -1,4 +1,5 @@ from textwrap import dedent +from typing import Dict, NoReturn from aiohttp import web @@ -8,6 +9,7 @@ from aiohttp_security import ( ) from .authz import check_credentials +from .users import User index_template = dedent(""" @@ -27,7 +29,7 @@ index_template = dedent(""" """) -async def index(request): +async def index(request: web.Request) -> web.Response: username = await authorized_userid(request) if username: template = index_template.format( @@ -40,22 +42,26 @@ async def index(request): ) -async def login(request): - response = web.HTTPFound('/') +async def login(request: web.Request) -> NoReturn: + user_map: Dict[str, User] = request.app['user_map'] + invalid_response = web.HTTPUnauthorized(body='Invalid username / password combination') form = await request.post() username = form.get('username') password = form.get('password') - verified = await check_credentials( - request.app.user_map, username, password) + if not (isinstance(username, str) and isinstance(password, str)): + raise invalid_response + + verified = await check_credentials(user_map, username, password) if verified: + response = web.HTTPFound('/') await remember(request, response, username) - return response + raise response - return web.HTTPUnauthorized(body='Invalid username / password combination') + raise invalid_response -async def logout(request): +async def logout(request: web.Request) -> web.Response: await check_authorized(request) response = web.Response( text='You have been logged out', @@ -65,7 +71,7 @@ async def logout(request): return response -async def internal_page(request): +async def internal_page(request: web.Request) -> web.Response: await check_permission(request, 'public') response = web.Response( text='This page is visible for all registered users', @@ -74,7 +80,7 @@ async def internal_page(request): return response -async def protected_page(request): +async def protected_page(request: web.Request) -> web.Response: await check_permission(request, 'protected') response = web.Response( text='You are on protected page', @@ -83,7 +89,7 @@ async def protected_page(request): return response -def configure_handlers(app): +def configure_handlers(app: web.Application) -> None: router = app.router router.add_get('/', index, name='index') router.add_post('/login', login, name='login') diff --git a/demo/dictionary_auth/main.py b/demo/dictionary_auth/main.py index b4fe2b4..0bd5f1a 100644 --- a/demo/dictionary_auth/main.py +++ b/demo/dictionary_auth/main.py @@ -11,9 +11,9 @@ from demo.dictionary_auth.handlers import configure_handlers from demo.dictionary_auth.users import user_map -def make_app(): +def make_app() -> web.Application: app = web.Application() - app.user_map = user_map + app['user_map'] = user_map configure_handlers(app) # secret_key must be 32 url-safe base64-encoded bytes diff --git a/demo/dictionary_auth/users.py b/demo/dictionary_auth/users.py index 967b2bb..65a9aa5 100644 --- a/demo/dictionary_auth/users.py +++ b/demo/dictionary_auth/users.py @@ -1,6 +1,11 @@ -from collections import namedtuple +from typing import NamedTuple, Tuple + + +class User(NamedTuple): + username: str + password: str + permissions: Tuple[str, ...] -User = namedtuple('User', ['username', 'password', 'permissions']) user_map = { user.username: user for user in [ diff --git a/demo/simple_example_auth.py b/demo/simple_example_auth.py index 5b4b82d..a613518 100644 --- a/demo/simple_example_auth.py +++ b/demo/simple_example_auth.py @@ -1,3 +1,6 @@ +from enum import Enum +from typing import NoReturn, Optional, Union + from aiohttp import web from aiohttp_session import SimpleCookieStorage, session_middleware from aiohttp_security import check_permission, \ @@ -11,15 +14,15 @@ from aiohttp_security.abc import AbstractAuthorizationPolicy # For more complicated authorization policies see examples # in the 'demo' directory. class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy): - async def authorized_userid(self, identity): + async def authorized_userid(self, identity: str) -> Optional[str]: """Retrieve authorized user id. Return the user_id of the user identified by the identity or 'None' if no user exists related to the identity. """ - if identity == 'jack': - return identity + return identity if identity == 'jack' else None - async def permits(self, identity, permission, context=None): + async def permits(self, identity: str, permission: Union[str, Enum], + context: None = None) -> bool: """Check user permissions. Return True if the identity is allowed the permission in the current context, else return False. @@ -27,7 +30,7 @@ class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy): return identity == 'jack' and permission in ('listen',) -async def handler_root(request): +async def handler_root(request: web.Request) -> web.Response: is_logged = not await is_anonymous(request) return web.Response(text='''