Add type annotations.
This commit is contained in:
parent
2247eb72f9
commit
b3b5367460
|
@ -0,0 +1,37 @@
|
||||||
|
[mypy]
|
||||||
|
files = aiohttp_security, demo, tests
|
||||||
|
check_untyped_defs = True
|
||||||
|
follow_imports_for_stubs = True
|
||||||
|
disallow_any_decorated = True
|
||||||
|
disallow_any_generics = True
|
||||||
|
disallow_incomplete_defs = True
|
||||||
|
disallow_subclassing_any = True
|
||||||
|
disallow_untyped_calls = True
|
||||||
|
disallow_untyped_decorators = True
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
implicit_reexport = False
|
||||||
|
no_implicit_optional = True
|
||||||
|
show_error_codes = True
|
||||||
|
strict_equality = True
|
||||||
|
warn_incomplete_stub = True
|
||||||
|
warn_redundant_casts = True
|
||||||
|
warn_unreachable = True
|
||||||
|
warn_unused_ignores = True
|
||||||
|
disallow_any_unimported = True
|
||||||
|
warn_return_any = True
|
||||||
|
|
||||||
|
[mypy-aiohttp_security.abc.*]
|
||||||
|
disallow_any_decorated = False
|
||||||
|
|
||||||
|
[mypy-tests.*]
|
||||||
|
disallow_any_decorated = False
|
||||||
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
|
[mypy-aiopg.*]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-aioredis.*]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-passlib.*]
|
||||||
|
ignore_missing_imports = True
|
|
@ -1,4 +1,8 @@
|
||||||
import abc
|
import abc
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
# see http://plope.com/pyramid_auth_design_api_postmortem
|
# see http://plope.com/pyramid_auth_design_api_postmortem
|
||||||
|
|
||||||
|
@ -6,13 +10,14 @@ import abc
|
||||||
class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
|
class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def identify(self, request):
|
async def identify(self, request: web.Request) -> Optional[str]:
|
||||||
"""Return the claimed identity of the user associated request or
|
"""Return the claimed identity of the user associated request or
|
||||||
``None`` if no identity can be found associated with the request."""
|
``None`` if no identity can be found associated with the request."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def remember(self, request, response, identity, **kwargs):
|
async def remember(self, request: web.Request, response: web.StreamResponse,
|
||||||
|
identity: str, **kwargs: Any) -> None:
|
||||||
"""Remember identity.
|
"""Remember identity.
|
||||||
|
|
||||||
Modify response object by filling it's headers with remembered user.
|
Modify response object by filling it's headers with remembered user.
|
||||||
|
@ -23,7 +28,7 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def forget(self, request, response):
|
async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
|
||||||
""" Modify response which can be used to 'forget' the
|
""" Modify response which can be used to 'forget' the
|
||||||
current identity on subsequent requests."""
|
current identity on subsequent requests."""
|
||||||
pass
|
pass
|
||||||
|
@ -32,7 +37,8 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
|
||||||
class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
|
class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def permits(self, identity, permission, context=None):
|
async def permits(self, identity: str, permission: Union[str, Enum],
|
||||||
|
context: Any = None) -> bool:
|
||||||
"""Check user permissions.
|
"""Check user permissions.
|
||||||
|
|
||||||
Return True if the identity is allowed the permission in the
|
Return True if the identity is allowed the permission in the
|
||||||
|
@ -41,7 +47,7 @@ class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def authorized_userid(self, identity):
|
async def authorized_userid(self, identity: str) -> Optional[str]:
|
||||||
"""Retrieve authorized user id.
|
"""Retrieve authorized user id.
|
||||||
|
|
||||||
Return the user_id of the user identified by the identity
|
Return the user_id of the user identified by the identity
|
||||||
|
|
|
@ -1,15 +1,23 @@
|
||||||
import enum
|
import enum
|
||||||
import warnings
|
import warnings
|
||||||
from aiohttp import web
|
|
||||||
from aiohttp_security.abc import (AbstractIdentityPolicy,
|
|
||||||
AbstractAuthorizationPolicy)
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from typing import Any, Callable, Optional, TypeVar, Union
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
from aiohttp_security.abc import AbstractAuthorizationPolicy, AbstractIdentityPolicy
|
||||||
|
|
||||||
IDENTITY_KEY = 'aiohttp_security_identity_policy'
|
IDENTITY_KEY = 'aiohttp_security_identity_policy'
|
||||||
AUTZ_KEY = 'aiohttp_security_autz_policy'
|
AUTZ_KEY = 'aiohttp_security_autz_policy'
|
||||||
|
|
||||||
|
# _AIP/_AAP are shorthand for Optional[policy] when we retrieve from request.
|
||||||
|
_AAP = Optional[AbstractAuthorizationPolicy]
|
||||||
|
_AIP = Optional[AbstractIdentityPolicy]
|
||||||
|
_Handler = TypeVar('_Handler', bound=Union[Callable[[web.Request], Any],
|
||||||
|
Callable[[object, web.Request], Any]])
|
||||||
|
|
||||||
async def remember(request, response, identity, **kwargs):
|
|
||||||
|
async def remember(request: web.Request, response: web.StreamResponse,
|
||||||
|
identity: str, **kwargs: Any) -> None:
|
||||||
"""Remember identity into response.
|
"""Remember identity into response.
|
||||||
|
|
||||||
The action is performed by identity_policy.remember()
|
The action is performed by identity_policy.remember()
|
||||||
|
@ -30,7 +38,7 @@ async def remember(request, response, identity, **kwargs):
|
||||||
await identity_policy.remember(request, response, identity, **kwargs)
|
await identity_policy.remember(request, response, identity, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
async def forget(request, response):
|
async def forget(request: web.Request, response: web.StreamResponse) -> None:
|
||||||
"""Forget previously remembered identity.
|
"""Forget previously remembered identity.
|
||||||
|
|
||||||
Usually it clears cookie or server-side storage to forget user
|
Usually it clears cookie or server-side storage to forget user
|
||||||
|
@ -47,9 +55,9 @@ async def forget(request, response):
|
||||||
await identity_policy.forget(request, response)
|
await identity_policy.forget(request, response)
|
||||||
|
|
||||||
|
|
||||||
async def authorized_userid(request):
|
async def authorized_userid(request: web.Request) -> Optional[str]:
|
||||||
identity_policy = request.config_dict.get(IDENTITY_KEY)
|
identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY)
|
||||||
autz_policy = request.config_dict.get(AUTZ_KEY)
|
autz_policy: _AAP = request.config_dict.get(AUTZ_KEY)
|
||||||
if identity_policy is None or autz_policy is None:
|
if identity_policy is None or autz_policy is None:
|
||||||
return None
|
return None
|
||||||
identity = await identity_policy.identify(request)
|
identity = await identity_policy.identify(request)
|
||||||
|
@ -59,20 +67,21 @@ async def authorized_userid(request):
|
||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
async def permits(request, permission, context=None):
|
async def permits(request: web.Request, permission: Union[str, enum.Enum],
|
||||||
|
context: Any = None) -> bool:
|
||||||
assert isinstance(permission, (str, enum.Enum)), permission
|
assert isinstance(permission, (str, enum.Enum)), permission
|
||||||
assert permission
|
assert permission
|
||||||
identity_policy = request.config_dict.get(IDENTITY_KEY)
|
identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY)
|
||||||
autz_policy = request.config_dict.get(AUTZ_KEY)
|
autz_policy: _AAP = request.config_dict.get(AUTZ_KEY)
|
||||||
if identity_policy is None or autz_policy is None:
|
if identity_policy is None or autz_policy is None:
|
||||||
return True
|
return True
|
||||||
identity = await identity_policy.identify(request)
|
identity = await identity_policy.identify(request)
|
||||||
# non-registered user still may has some permissions
|
# non-registered user still may have some permissions
|
||||||
access = await autz_policy.permits(identity, permission, context)
|
access = await autz_policy.permits(identity, permission, context)
|
||||||
return access
|
return access
|
||||||
|
|
||||||
|
|
||||||
async def is_anonymous(request):
|
async def is_anonymous(request: web.Request) -> bool:
|
||||||
"""Check if user is anonymous.
|
"""Check if user is anonymous.
|
||||||
|
|
||||||
User is considered anonymous if there is not identity
|
User is considered anonymous if there is not identity
|
||||||
|
@ -87,7 +96,7 @@ async def is_anonymous(request):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def check_authorized(request):
|
async def check_authorized(request: web.Request) -> str:
|
||||||
"""Checker that raises HTTPUnauthorized for anonymous users.
|
"""Checker that raises HTTPUnauthorized for anonymous users.
|
||||||
"""
|
"""
|
||||||
userid = await authorized_userid(request)
|
userid = await authorized_userid(request)
|
||||||
|
@ -96,31 +105,32 @@ async def check_authorized(request):
|
||||||
return userid
|
return userid
|
||||||
|
|
||||||
|
|
||||||
def login_required(fn):
|
def login_required(fn: _Handler) -> _Handler:
|
||||||
"""Decorator that restrict access only for authorized users.
|
"""Decorator that restrict access only for authorized users.
|
||||||
|
|
||||||
User is considered authorized if authorized_userid
|
User is considered authorized if authorized_userid
|
||||||
returns some value.
|
returns some value.
|
||||||
"""
|
"""
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
async def wrapped(*args, **kwargs):
|
async def wrapped(*args: Union[object, web.Request]) -> Any:
|
||||||
request = args[-1]
|
request = args[-1]
|
||||||
if not isinstance(request, web.BaseRequest):
|
if not isinstance(request, web.Request):
|
||||||
msg = ("Incorrect decorator usage. "
|
msg = ("Incorrect decorator usage. "
|
||||||
"Expecting `def handler(request)` "
|
"Expecting `def handler(request)` "
|
||||||
"or `def handler(self, request)`.")
|
"or `def handler(self, request)`.")
|
||||||
raise RuntimeError(msg)
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
await check_authorized(request)
|
await check_authorized(request)
|
||||||
return await fn(*args, **kwargs)
|
return await fn(*args) # type: ignore[arg-type]
|
||||||
|
|
||||||
warnings.warn("login_required decorator is deprecated, "
|
warnings.warn("login_required decorator is deprecated, "
|
||||||
"use check_authorized instead",
|
"use check_authorized instead",
|
||||||
DeprecationWarning)
|
DeprecationWarning)
|
||||||
return wrapped
|
return wrapped # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
async def check_permission(request, permission, context=None):
|
async def check_permission(request: web.Request, permission: Union[str, enum.Enum],
|
||||||
|
context: Any = None) -> None:
|
||||||
"""Checker that passes only to authoraised users with given permission.
|
"""Checker that passes only to authoraised users with given permission.
|
||||||
|
|
||||||
If user is not authorized - raises HTTPUnauthorized,
|
If user is not authorized - raises HTTPUnauthorized,
|
||||||
|
@ -134,10 +144,7 @@ async def check_permission(request, permission, context=None):
|
||||||
raise web.HTTPForbidden()
|
raise web.HTTPForbidden()
|
||||||
|
|
||||||
|
|
||||||
def has_permission(
|
def has_permission(permission: Union[str, enum.Enum], context: Any = None): # type: ignore
|
||||||
permission,
|
|
||||||
context=None,
|
|
||||||
):
|
|
||||||
"""Decorator that restricts access only for authorized users
|
"""Decorator that restricts access only for authorized users
|
||||||
with correct permissions.
|
with correct permissions.
|
||||||
|
|
||||||
|
@ -145,11 +152,11 @@ def has_permission(
|
||||||
if user is authorized and does not have permission -
|
if user is authorized and does not have permission -
|
||||||
raises HTTPForbidden.
|
raises HTTPForbidden.
|
||||||
"""
|
"""
|
||||||
def wrapper(fn):
|
def wrapper(fn): # type: ignore
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
async def wrapped(*args, **kwargs):
|
async def wrapped(*args, **kwargs): # type: ignore
|
||||||
request = args[-1]
|
request = args[-1]
|
||||||
if not isinstance(request, web.BaseRequest):
|
if not isinstance(request, web.Request):
|
||||||
msg = ("Incorrect decorator usage. "
|
msg = ("Incorrect decorator usage. "
|
||||||
"Expecting `def handler(request)` "
|
"Expecting `def handler(request)` "
|
||||||
"or `def handler(self, request)`.")
|
"or `def handler(self, request)`.")
|
||||||
|
@ -166,7 +173,8 @@ def has_permission(
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def setup(app, identity_policy, autz_policy):
|
def setup(app: web.Application, identity_policy: AbstractIdentityPolicy,
|
||||||
|
autz_policy: AbstractAuthorizationPolicy) -> None:
|
||||||
assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy
|
assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy
|
||||||
assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy
|
assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy
|
||||||
|
|
||||||
|
|
|
@ -5,28 +5,32 @@ more handy.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
from typing import Any, NewType, Optional, Union, cast
|
||||||
|
|
||||||
from .abc import AbstractIdentityPolicy
|
from .abc import AbstractIdentityPolicy
|
||||||
|
|
||||||
|
_Sentinel = NewType('_Sentinel', object)
|
||||||
sentinel = object()
|
sentinel = _Sentinel(object())
|
||||||
|
|
||||||
|
|
||||||
class CookiesIdentityPolicy(AbstractIdentityPolicy):
|
class CookiesIdentityPolicy(AbstractIdentityPolicy):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._cookie_name = 'AIOHTTP_SECURITY'
|
self._cookie_name = 'AIOHTTP_SECURITY'
|
||||||
self._max_age = 30 * 24 * 3600
|
self._max_age = 30 * 24 * 3600
|
||||||
|
|
||||||
async def identify(self, request):
|
async def identify(self, request: web.Request) -> Optional[str]:
|
||||||
identity = request.cookies.get(self._cookie_name)
|
return request.cookies.get(self._cookie_name)
|
||||||
return identity
|
|
||||||
|
|
||||||
async def remember(self, request, response, identity, max_age=sentinel,
|
async def remember(self, request: web.Request, response: web.StreamResponse,
|
||||||
**kwargs):
|
identity: str, max_age: Union[_Sentinel, Optional[int]] = sentinel,
|
||||||
|
**kwargs: Any) -> None:
|
||||||
if max_age is sentinel:
|
if max_age is sentinel:
|
||||||
max_age = self._max_age
|
max_age = self._max_age
|
||||||
|
max_age = cast(Optional[int], max_age)
|
||||||
response.set_cookie(self._cookie_name, identity,
|
response.set_cookie(self._cookie_name, identity,
|
||||||
max_age=max_age, **kwargs)
|
max_age=max_age, **kwargs)
|
||||||
|
|
||||||
async def forget(self, request, response):
|
async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
|
||||||
response.del_cookie(self._cookie_name)
|
response.del_cookie(self._cookie_name)
|
||||||
|
|
|
@ -2,12 +2,17 @@
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
from .abc import AbstractIdentityPolicy
|
from .abc import AbstractIdentityPolicy
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jwt
|
import jwt
|
||||||
|
HAS_JWT = True
|
||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
jwt = None
|
HAS_JWT = False
|
||||||
|
|
||||||
|
|
||||||
AUTH_HEADER_NAME = 'Authorization'
|
AUTH_HEADER_NAME = 'Authorization'
|
||||||
|
@ -15,21 +20,21 @@ AUTH_SCHEME = 'Bearer '
|
||||||
|
|
||||||
|
|
||||||
class JWTIdentityPolicy(AbstractIdentityPolicy):
|
class JWTIdentityPolicy(AbstractIdentityPolicy):
|
||||||
def __init__(self, secret, algorithm='HS256'):
|
def __init__(self, secret: str, algorithm: str = 'HS256'):
|
||||||
if jwt is None:
|
if not HAS_JWT:
|
||||||
raise RuntimeError('Please install `PyJWT`')
|
raise RuntimeError('Please install `PyJWT`')
|
||||||
self.secret = secret
|
self.secret = secret
|
||||||
self.algorithm = algorithm
|
self.algorithm = algorithm
|
||||||
|
|
||||||
async def identify(self, request):
|
async def identify(self, request: web.Request) -> Optional[str]:
|
||||||
header_identity = request.headers.get(AUTH_HEADER_NAME)
|
header_identity = request.headers.get(AUTH_HEADER_NAME)
|
||||||
|
|
||||||
if header_identity is None:
|
if header_identity is None:
|
||||||
return
|
return None
|
||||||
|
|
||||||
if not header_identity.startswith(AUTH_SCHEME):
|
if not header_identity.startswith(AUTH_SCHEME):
|
||||||
raise ValueError('Invalid authorization scheme. ' +
|
raise ValueError('Invalid authorization scheme. ' +
|
||||||
'Should be `Bearer <token>`')
|
'Should be `{}<token>`'.format(AUTH_SCHEME))
|
||||||
|
|
||||||
token = header_identity.split(' ')[1].strip()
|
token = header_identity.split(' ')[1].strip()
|
||||||
|
|
||||||
|
@ -38,8 +43,9 @@ class JWTIdentityPolicy(AbstractIdentityPolicy):
|
||||||
algorithms=[self.algorithm])
|
algorithms=[self.algorithm])
|
||||||
return identity
|
return identity
|
||||||
|
|
||||||
async def remember(self, *args, **kwargs): # pragma: no cover
|
async def remember(self, request: web.Request, response: web.StreamResponse,
|
||||||
|
identity: str, **kwargs: None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def forget(self, request, response): # pragma: no cover
|
async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -4,6 +4,7 @@ aiohttp_session.setup() should be called on application initialization
|
||||||
to configure aiohttp_session properly.
|
to configure aiohttp_session properly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
try:
|
try:
|
||||||
from aiohttp_session import get_session
|
from aiohttp_session import get_session
|
||||||
HAS_AIOHTTP_SESSION = True
|
HAS_AIOHTTP_SESSION = True
|
||||||
|
@ -15,21 +16,22 @@ from .abc import AbstractIdentityPolicy
|
||||||
|
|
||||||
class SessionIdentityPolicy(AbstractIdentityPolicy):
|
class SessionIdentityPolicy(AbstractIdentityPolicy):
|
||||||
|
|
||||||
def __init__(self, session_key='AIOHTTP_SECURITY'):
|
def __init__(self, session_key: str = 'AIOHTTP_SECURITY'):
|
||||||
self._session_key = session_key
|
self._session_key = session_key
|
||||||
|
|
||||||
if not HAS_AIOHTTP_SESSION: # pragma: no cover
|
if not HAS_AIOHTTP_SESSION: # pragma: no cover
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
'SessionIdentityPolicy requires `aiohttp_session`')
|
'SessionIdentityPolicy requires `aiohttp_session`')
|
||||||
|
|
||||||
async def identify(self, request):
|
async def identify(self, request: web.Request) -> str:
|
||||||
session = await get_session(request)
|
session = await get_session(request)
|
||||||
return session.get(self._session_key)
|
return session.get(self._session_key)
|
||||||
|
|
||||||
async def remember(self, request, response, identity, **kwargs):
|
async def remember(self, request: web.Request, response: web.StreamResponse,
|
||||||
|
identity: str, **kwargs: None) -> None:
|
||||||
session = await get_session(request)
|
session = await get_session(request)
|
||||||
session[self._session_key] = identity
|
session[self._session_key] = identity
|
||||||
|
|
||||||
async def forget(self, request, response):
|
async def forget(self, request: web.Request, response: web.StreamResponse) -> None:
|
||||||
session = await get_session(request)
|
session = await get_session(request)
|
||||||
session.pop(self._session_key, None)
|
session.pop(self._session_key, None)
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import sqlalchemy as sa
|
from enum import Enum
|
||||||
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||||
from passlib.hash import sha256_crypt
|
from passlib.hash import sha256_crypt
|
||||||
|
|
||||||
|
@ -7,13 +9,13 @@ from . import db
|
||||||
|
|
||||||
|
|
||||||
class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
def __init__(self, dbengine):
|
def __init__(self, dbengine: Any):
|
||||||
self.dbengine = dbengine
|
self.dbengine = dbengine
|
||||||
|
|
||||||
async def authorized_userid(self, identity):
|
async def authorized_userid(self, identity: str) -> Optional[str]:
|
||||||
async with self.dbengine.acquire() as conn:
|
async with self.dbengine.acquire() as conn:
|
||||||
where = sa.and_(db.users.c.login == identity,
|
where = sa.and_(db.users.c.login == identity,
|
||||||
sa.not_(db.users.c.disabled))
|
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
|
||||||
query = db.users.count().where(where)
|
query = db.users.count().where(where)
|
||||||
ret = await conn.scalar(query)
|
ret = await conn.scalar(query)
|
||||||
if ret:
|
if ret:
|
||||||
|
@ -21,13 +23,11 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def permits(self, identity, permission, context=None):
|
async def permits(self, identity: str, permission: Union[str, Enum],
|
||||||
if identity is None:
|
context: None = None) -> bool:
|
||||||
return False
|
|
||||||
|
|
||||||
async with self.dbengine.acquire() as conn:
|
async with self.dbengine.acquire() as conn:
|
||||||
where = sa.and_(db.users.c.login == identity,
|
where = sa.and_(db.users.c.login == identity,
|
||||||
sa.not_(db.users.c.disabled))
|
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
|
||||||
query = db.users.select().where(where)
|
query = db.users.select().where(where)
|
||||||
ret = await conn.execute(query)
|
ret = await conn.execute(query)
|
||||||
user = await ret.fetchone()
|
user = await ret.fetchone()
|
||||||
|
@ -49,14 +49,14 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def check_credentials(db_engine, username, password):
|
async def check_credentials(db_engine: Any, username: str, password: str) -> bool:
|
||||||
async with db_engine.acquire() as conn:
|
async with db_engine.acquire() as conn:
|
||||||
where = sa.and_(db.users.c.login == username,
|
where = sa.and_(db.users.c.login == username,
|
||||||
sa.not_(db.users.c.disabled))
|
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
|
||||||
query = db.users.select().where(where)
|
query = db.users.select().where(where)
|
||||||
ret = await conn.execute(query)
|
ret = await conn.execute(query)
|
||||||
user = await ret.fetchone()
|
user = await ret.fetchone()
|
||||||
if user is not None:
|
if user is not None:
|
||||||
hashed = user[2]
|
hashed = user[2]
|
||||||
return sha256_crypt.verify(password, hashed)
|
return sha256_crypt.verify(password, hashed) # type: ignore[no-any-return]
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
from typing import NoReturn
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
@ -27,7 +28,7 @@ class Web(object):
|
||||||
</body>
|
</body>
|
||||||
""")
|
""")
|
||||||
|
|
||||||
async def index(self, request):
|
async def index(self, request: web.Request) -> web.Response:
|
||||||
username = await authorized_userid(request)
|
username = await authorized_userid(request)
|
||||||
if username:
|
if username:
|
||||||
template = self.index_template.format(
|
template = self.index_template.format(
|
||||||
|
@ -37,37 +38,41 @@ class Web(object):
|
||||||
response = web.Response(body=template.encode())
|
response = web.Response(body=template.encode())
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def login(self, request):
|
async def login(self, request: web.Request) -> NoReturn:
|
||||||
response = web.HTTPFound('/')
|
invalid_resp = web.HTTPUnauthorized(body=b'Invalid username/password combination')
|
||||||
form = await request.post()
|
form = await request.post()
|
||||||
login = form.get('login')
|
login = form.get('login')
|
||||||
password = form.get('password')
|
password = form.get('password')
|
||||||
db_engine = request.app.db_engine
|
db_engine = request.app['db_engine']
|
||||||
|
|
||||||
|
if not (isinstance(login, str) and isinstance(password, str)):
|
||||||
|
raise invalid_resp
|
||||||
|
|
||||||
if await check_credentials(db_engine, login, password):
|
if await check_credentials(db_engine, login, password):
|
||||||
|
response = web.HTTPFound('/')
|
||||||
await remember(request, response, login)
|
await remember(request, response, login)
|
||||||
raise response
|
raise response
|
||||||
|
|
||||||
raise web.HTTPUnauthorized(
|
raise invalid_resp
|
||||||
body=b'Invalid username/password combination')
|
|
||||||
|
|
||||||
async def logout(self, request):
|
async def logout(self, request: web.Request) -> web.Response:
|
||||||
await check_authorized(request)
|
await check_authorized(request)
|
||||||
response = web.Response(body=b'You have been logged out')
|
response = web.Response(body=b'You have been logged out')
|
||||||
await forget(request, response)
|
await forget(request, response)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def internal_page(self, request):
|
async def internal_page(self, request: web.Request) -> web.Response:
|
||||||
await check_permission(request, 'public')
|
await check_permission(request, 'public')
|
||||||
response = web.Response(
|
response = web.Response(
|
||||||
body=b'This page is visible for all registered users')
|
body=b'This page is visible for all registered users')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def protected_page(self, request):
|
async def protected_page(self, request: web.Request) -> web.Response:
|
||||||
await check_permission(request, 'protected')
|
await check_permission(request, 'protected')
|
||||||
response = web.Response(body=b'You are on protected page')
|
response = web.Response(body=b'You are on protected page')
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def configure(self, app):
|
def configure(self, app: web.Application) -> None:
|
||||||
router = app.router
|
router = app.router
|
||||||
router.add_route('GET', '/', self.index, name='index')
|
router.add_route('GET', '/', self.index, name='index')
|
||||||
router.add_route('POST', '/login', self.login, name='login')
|
router.add_route('POST', '/login', self.login, name='login')
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Any, Tuple
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp_session import setup as setup_session
|
from aiohttp_session import setup as setup_session
|
||||||
|
@ -13,14 +14,14 @@ from demo.database_auth.db_auth import DBAuthorizationPolicy
|
||||||
from demo.database_auth.handlers import Web
|
from demo.database_auth.handlers import Web
|
||||||
|
|
||||||
|
|
||||||
async def init(loop):
|
async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]:
|
||||||
redis_pool = await create_pool(('localhost', 6379))
|
redis_pool = await create_pool(('localhost', 6379))
|
||||||
db_engine = await create_engine(user='aiohttp_security',
|
db_engine = await create_engine(user='aiohttp_security',
|
||||||
password='aiohttp_security',
|
password='aiohttp_security',
|
||||||
database='aiohttp_security',
|
database='aiohttp_security',
|
||||||
host='127.0.0.1')
|
host='127.0.0.1')
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.db_engine = db_engine
|
app['db_engine'] = db_engine
|
||||||
setup_session(app, RedisStorage(redis_pool))
|
setup_session(app, RedisStorage(redis_pool))
|
||||||
setup_security(app,
|
setup_security(app,
|
||||||
SessionIdentityPolicy(),
|
SessionIdentityPolicy(),
|
||||||
|
@ -35,7 +36,7 @@ async def init(loop):
|
||||||
return srv, app, handler
|
return srv, app, handler
|
||||||
|
|
||||||
|
|
||||||
async def finalize(srv, app, handler):
|
async def finalize(srv: Any, app: Any, handler: Any) -> None:
|
||||||
sock = srv.sockets[0]
|
sock = srv.sockets[0]
|
||||||
app.loop.remove_reader(sock.fileno())
|
app.loop.remove_reader(sock.fileno())
|
||||||
sock.close()
|
sock.close()
|
||||||
|
@ -46,7 +47,7 @@ async def finalize(srv, app, handler):
|
||||||
await app.finish()
|
await app.finish()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
srv, app, handler = loop.run_until_complete(init(loop))
|
srv, app, handler = loop.run_until_complete(init(loop))
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,20 +1,25 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||||
|
|
||||||
|
from .users import User
|
||||||
|
|
||||||
|
|
||||||
class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
|
class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
def __init__(self, user_map):
|
def __init__(self, user_map: Dict[str, User]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.user_map = user_map
|
self.user_map = user_map
|
||||||
|
|
||||||
async def authorized_userid(self, identity):
|
async def authorized_userid(self, identity: str) -> Optional[str]:
|
||||||
"""Retrieve authorized user id.
|
"""Retrieve authorized user id.
|
||||||
Return the user_id of the user identified by the identity
|
Return the user_id of the user identified by the identity
|
||||||
or 'None' if no user exists related to the identity.
|
or 'None' if no user exists related to the identity.
|
||||||
"""
|
"""
|
||||||
if identity in self.user_map:
|
return identity if identity in self.user_map else None
|
||||||
return identity
|
|
||||||
|
|
||||||
async def permits(self, identity, permission, context=None):
|
async def permits(self, identity: str, permission: Union[str, Enum],
|
||||||
|
context: None = None) -> bool:
|
||||||
"""Check user permissions.
|
"""Check user permissions.
|
||||||
Return True if the identity is allowed the permission in the
|
Return True if the identity is allowed the permission in the
|
||||||
current context, else return False.
|
current context, else return False.
|
||||||
|
@ -26,7 +31,7 @@ class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
return permission in user.permissions
|
return permission in user.permissions
|
||||||
|
|
||||||
|
|
||||||
async def check_credentials(user_map, username, password):
|
async def check_credentials(user_map: Dict[str, User], username: str, password: str) -> bool:
|
||||||
user = user_map.get(username)
|
user = user_map.get(username)
|
||||||
if not user:
|
if not user:
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
|
from typing import Dict, NoReturn
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
@ -8,6 +9,7 @@ from aiohttp_security import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .authz import check_credentials
|
from .authz import check_credentials
|
||||||
|
from .users import User
|
||||||
|
|
||||||
|
|
||||||
index_template = dedent("""
|
index_template = dedent("""
|
||||||
|
@ -27,7 +29,7 @@ index_template = dedent("""
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
async def index(request):
|
async def index(request: web.Request) -> web.Response:
|
||||||
username = await authorized_userid(request)
|
username = await authorized_userid(request)
|
||||||
if username:
|
if username:
|
||||||
template = index_template.format(
|
template = index_template.format(
|
||||||
|
@ -40,22 +42,26 @@ async def index(request):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def login(request):
|
async def login(request: web.Request) -> NoReturn:
|
||||||
response = web.HTTPFound('/')
|
user_map: Dict[str, User] = request.app['user_map']
|
||||||
|
invalid_response = web.HTTPUnauthorized(body='Invalid username / password combination')
|
||||||
form = await request.post()
|
form = await request.post()
|
||||||
username = form.get('username')
|
username = form.get('username')
|
||||||
password = form.get('password')
|
password = form.get('password')
|
||||||
|
|
||||||
verified = await check_credentials(
|
if not (isinstance(username, str) and isinstance(password, str)):
|
||||||
request.app.user_map, username, password)
|
raise invalid_response
|
||||||
|
|
||||||
|
verified = await check_credentials(user_map, username, password)
|
||||||
if verified:
|
if verified:
|
||||||
|
response = web.HTTPFound('/')
|
||||||
await remember(request, response, username)
|
await remember(request, response, username)
|
||||||
return response
|
raise response
|
||||||
|
|
||||||
return web.HTTPUnauthorized(body='Invalid username / password combination')
|
raise invalid_response
|
||||||
|
|
||||||
|
|
||||||
async def logout(request):
|
async def logout(request: web.Request) -> web.Response:
|
||||||
await check_authorized(request)
|
await check_authorized(request)
|
||||||
response = web.Response(
|
response = web.Response(
|
||||||
text='You have been logged out',
|
text='You have been logged out',
|
||||||
|
@ -65,7 +71,7 @@ async def logout(request):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def internal_page(request):
|
async def internal_page(request: web.Request) -> web.Response:
|
||||||
await check_permission(request, 'public')
|
await check_permission(request, 'public')
|
||||||
response = web.Response(
|
response = web.Response(
|
||||||
text='This page is visible for all registered users',
|
text='This page is visible for all registered users',
|
||||||
|
@ -74,7 +80,7 @@ async def internal_page(request):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def protected_page(request):
|
async def protected_page(request: web.Request) -> web.Response:
|
||||||
await check_permission(request, 'protected')
|
await check_permission(request, 'protected')
|
||||||
response = web.Response(
|
response = web.Response(
|
||||||
text='You are on protected page',
|
text='You are on protected page',
|
||||||
|
@ -83,7 +89,7 @@ async def protected_page(request):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def configure_handlers(app):
|
def configure_handlers(app: web.Application) -> None:
|
||||||
router = app.router
|
router = app.router
|
||||||
router.add_get('/', index, name='index')
|
router.add_get('/', index, name='index')
|
||||||
router.add_post('/login', login, name='login')
|
router.add_post('/login', login, name='login')
|
||||||
|
|
|
@ -11,9 +11,9 @@ from demo.dictionary_auth.handlers import configure_handlers
|
||||||
from demo.dictionary_auth.users import user_map
|
from demo.dictionary_auth.users import user_map
|
||||||
|
|
||||||
|
|
||||||
def make_app():
|
def make_app() -> web.Application:
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.user_map = user_map
|
app['user_map'] = user_map
|
||||||
configure_handlers(app)
|
configure_handlers(app)
|
||||||
|
|
||||||
# secret_key must be 32 url-safe base64-encoded bytes
|
# secret_key must be 32 url-safe base64-encoded bytes
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
from collections import namedtuple
|
from typing import NamedTuple, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class User(NamedTuple):
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
permissions: Tuple[str, ...]
|
||||||
|
|
||||||
User = namedtuple('User', ['username', 'password', 'permissions'])
|
|
||||||
|
|
||||||
user_map = {
|
user_map = {
|
||||||
user.username: user for user in [
|
user.username: user for user in [
|
||||||
|
|
|
@ -1,3 +1,6 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import NoReturn, Optional, Union
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp_session import SimpleCookieStorage, session_middleware
|
from aiohttp_session import SimpleCookieStorage, session_middleware
|
||||||
from aiohttp_security import check_permission, \
|
from aiohttp_security import check_permission, \
|
||||||
|
@ -11,15 +14,15 @@ from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||||
# For more complicated authorization policies see examples
|
# For more complicated authorization policies see examples
|
||||||
# in the 'demo' directory.
|
# in the 'demo' directory.
|
||||||
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
|
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
async def authorized_userid(self, identity):
|
async def authorized_userid(self, identity: str) -> Optional[str]:
|
||||||
"""Retrieve authorized user id.
|
"""Retrieve authorized user id.
|
||||||
Return the user_id of the user identified by the identity
|
Return the user_id of the user identified by the identity
|
||||||
or 'None' if no user exists related to the identity.
|
or 'None' if no user exists related to the identity.
|
||||||
"""
|
"""
|
||||||
if identity == 'jack':
|
return identity if identity == 'jack' else None
|
||||||
return identity
|
|
||||||
|
|
||||||
async def permits(self, identity, permission, context=None):
|
async def permits(self, identity: str, permission: Union[str, Enum],
|
||||||
|
context: None = None) -> bool:
|
||||||
"""Check user permissions.
|
"""Check user permissions.
|
||||||
Return True if the identity is allowed the permission
|
Return True if the identity is allowed the permission
|
||||||
in the current context, else return False.
|
in the current context, else return False.
|
||||||
|
@ -27,7 +30,7 @@ class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||||
return identity == 'jack' and permission in ('listen',)
|
return identity == 'jack' and permission in ('listen',)
|
||||||
|
|
||||||
|
|
||||||
async def handler_root(request):
|
async def handler_root(request: web.Request) -> web.Response:
|
||||||
is_logged = not await is_anonymous(request)
|
is_logged = not await is_anonymous(request)
|
||||||
return web.Response(text='''<html><head></head><body>
|
return web.Response(text='''<html><head></head><body>
|
||||||
Hello, I'm Jack, I'm {logged} logged in.<br /><br />
|
Hello, I'm Jack, I'm {logged} logged in.<br /><br />
|
||||||
|
@ -42,29 +45,29 @@ async def handler_root(request):
|
||||||
), content_type='text/html')
|
), content_type='text/html')
|
||||||
|
|
||||||
|
|
||||||
async def handler_login_jack(request):
|
async def handler_login_jack(request: web.Request) -> NoReturn:
|
||||||
redirect_response = web.HTTPFound('/')
|
redirect_response = web.HTTPFound('/')
|
||||||
await remember(request, redirect_response, 'jack')
|
await remember(request, redirect_response, 'jack')
|
||||||
raise redirect_response
|
raise redirect_response
|
||||||
|
|
||||||
|
|
||||||
async def handler_logout(request):
|
async def handler_logout(request: web.Request) -> NoReturn:
|
||||||
redirect_response = web.HTTPFound('/')
|
redirect_response = web.HTTPFound('/')
|
||||||
await forget(request, redirect_response)
|
await forget(request, redirect_response)
|
||||||
raise redirect_response
|
raise redirect_response
|
||||||
|
|
||||||
|
|
||||||
async def handler_listen(request):
|
async def handler_listen(request: web.Request) -> web.Response:
|
||||||
await check_permission(request, 'listen')
|
await check_permission(request, 'listen')
|
||||||
return web.Response(body="I can listen!")
|
return web.Response(body="I can listen!")
|
||||||
|
|
||||||
|
|
||||||
async def handler_speak(request):
|
async def handler_speak(request: web.Request) -> web.Response:
|
||||||
await check_permission(request, 'speak')
|
await check_permission(request, 'speak')
|
||||||
return web.Response(body="I can speak!")
|
return web.Response(body="I can speak!")
|
||||||
|
|
||||||
|
|
||||||
async def make_app():
|
async def make_app() -> web.Application:
|
||||||
#
|
#
|
||||||
# WARNING!!!
|
# WARNING!!!
|
||||||
# Never use SimpleCookieStorage on production!!!
|
# Never use SimpleCookieStorage on production!!!
|
||||||
|
|
|
@ -67,7 +67,7 @@ async def test_identify_broken_scheme(loop, make_token, aiohttp_client):
|
||||||
try:
|
try:
|
||||||
await policy.identify(request)
|
await policy.identify(request)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
raise web.HTTPBadRequest(reason=exc)
|
raise web.HTTPBadRequest(reason=str(exc))
|
||||||
|
|
||||||
return web.Response()
|
return web.Response()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue