26 Commits
typing ... ci

Author SHA1 Message Date
Sam Bull
dd3499008a Add name 2021-01-30 23:23:50 +00:00
Sam Bull
28c8e0f1bb Add name 2021-01-30 23:23:38 +00:00
Sam Bull
6211c9e782 Remove redundant install 2021-01-30 23:10:51 +00:00
Sam Bull
bfa9ebd6c4 Remove master 2021-01-30 23:10:12 +00:00
Sam Bull
1031308b0f Update .github/workflows/pypi.yml
Co-authored-by: Sviatoslav Sydorenko <sviat@redhat.com>
2021-01-30 23:07:09 +00:00
Sam Bull
b2424dd716 Update ci.yaml 2021-01-30 18:56:08 +00:00
Sam Bull
163069e3e9 Create pypi.yml 2021-01-30 18:55:24 +00:00
Sam Bull
b27d9013b9 Update ci.yaml 2021-01-17 15:14:51 +00:00
Sam Bull
9b2045d6ae Update ci.yaml 2021-01-17 13:54:20 +00:00
Sam Bull
d0fc1ed673 Update ci.yaml 2021-01-15 17:06:41 +00:00
Sam Bull
35c7625dcb Update ci.yaml 2021-01-15 17:00:21 +00:00
Sam Bull
8af169f639 Update .github/workflows/ci.yaml
Co-authored-by: Sviatoslav Sydorenko <sviat@redhat.com>
2021-01-15 12:38:10 +00:00
Sam Bull
e53cb1cc05 Update .github/workflows/ci.yaml
Co-authored-by: Sviatoslav Sydorenko <sviat@redhat.com>
2021-01-15 12:31:54 +00:00
Sam Bull
3d1675534d Update ci.yaml 2021-01-15 10:42:36 +00:00
Sam Bull
41ad3f64e7 Update ci.yaml 2021-01-15 10:41:44 +00:00
Sam Bull
ddc64c36e5 Update ci.yaml 2021-01-15 10:38:05 +00:00
Sam Bull
34a7b97cd5 Update ci.yaml 2021-01-15 10:36:50 +00:00
Sam Bull
36c9ea0015 Update ci.yaml 2021-01-15 10:36:32 +00:00
Sam Bull
3c2c28b694 Update Makefile 2021-01-15 10:26:45 +00:00
Sam Bull
663f953379 Update ci.yaml 2021-01-15 10:25:04 +00:00
Sam Bull
80cf5977d3 Update Makefile 2021-01-15 09:59:00 +00:00
Sam Bull
980695382f Update ci.yaml 2021-01-14 23:42:53 +00:00
Sam Bull
1c8ecc65e8 Update setup.cfg 2020-12-18 22:20:54 +00:00
Sam Bull
3b54ea334e Update setup.cfg 2020-12-18 22:15:00 +00:00
Sam Bull
d77613de91 Update requirements-dev.txt 2020-12-18 22:12:55 +00:00
Sam Bull
ead5cff442 Update ci.yaml 2020-12-18 18:15:21 +00:00
25 changed files with 164 additions and 218 deletions

View File

@@ -1,26 +1,27 @@
name: Test name: Tests
on: pull_request on: pull_request
jobs: jobs:
mypy:
name: Check annotations with Mypy
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- run: pip install aiohttp mypy
- run: mypy
test: test:
name: Tests name: Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
continue-on-error: ${{ matrix.experimental }}
strategy: strategy:
matrix: matrix:
python-version: [3.6, 3.7, 3.8, 3.9] python-version:
- 3.6
- 3.7
- 3.8
- 3.9
experimental: [false]
include:
- python-version: 3.10.0-alpha - 3.10.0
experimental: true
steps: steps:
- uses: actions/checkout@v2 - name: Checkout
uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2 uses: actions/setup-python@v2
with: with:
@@ -29,10 +30,11 @@ jobs:
run: | run: |
pip install --upgrade pip pip install --upgrade pip
pip install -r requirements-dev.txt pip install -r requirements-dev.txt
pip install codecov
- name: Run tests - name: Run tests
run: | run: |
make coverage make coverage
- name: Upload coverage to Codecov - name: Upload coverage
run: | uses: codecov/codecov-action@v1
codecov with:
file: ./coverage.xml
flags: unit

31
.github/workflows/pypi.yml vendored Normal file
View File

@@ -0,0 +1,31 @@
name: Publish to PyPI
on:
push:
tags: [ 'v*' ]
env:
DEFAULT_PYTHON: 3.9
jobs:
publish:
name: Publish to PyPI
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ env.DEFAULT_PYTHON }}
- name: Install dependencies
run: |
pip install --upgrade build
- name: Build
run: |
python -m build
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@@ -1,37 +0,0 @@
[mypy]
files = aiohttp_security, demo, tests
check_untyped_defs = True
follow_imports_for_stubs = True
disallow_any_decorated = True
disallow_any_generics = True
disallow_incomplete_defs = True
disallow_subclassing_any = True
disallow_untyped_calls = True
disallow_untyped_decorators = True
disallow_untyped_defs = True
implicit_reexport = False
no_implicit_optional = True
show_error_codes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
warn_unreachable = True
warn_unused_ignores = True
disallow_any_unimported = True
warn_return_any = True
[mypy-aiohttp_security.abc.*]
disallow_any_decorated = False
[mypy-tests.*]
disallow_any_decorated = False
disallow_untyped_defs = False
[mypy-aiopg.*]
ignore_missing_imports = True
[mypy-aioredis.*]
ignore_missing_imports = True
[mypy-passlib.*]
ignore_missing_imports = True

View File

@@ -11,7 +11,7 @@ vtest: flake
py.test -s ./tests/ py.test -s ./tests/
cov cover coverage: flake cov cover coverage: flake
py.test -s ./tests/ --cov=aiohttp_security --cov=tests --cov-report=html --cov-report=term py.test -s ./tests/ --cov=aiohttp_security --cov=tests --cov-report=html --cov-report=xml --cov-report=term
@echo "open file://`pwd`/coverage/index.html" @echo "open file://`pwd`/coverage/index.html"
clean: clean:

View File

@@ -1,8 +1,4 @@
import abc import abc
from enum import Enum
from typing import Any, Optional, Union
from aiohttp import web
# see http://plope.com/pyramid_auth_design_api_postmortem # see http://plope.com/pyramid_auth_design_api_postmortem
@@ -10,14 +6,13 @@ from aiohttp import web
class AbstractIdentityPolicy(metaclass=abc.ABCMeta): class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def identify(self, request: web.Request) -> Optional[str]: async def identify(self, request):
"""Return the claimed identity of the user associated request or """Return the claimed identity of the user associated request or
``None`` if no identity can be found associated with the request.""" ``None`` if no identity can be found associated with the request."""
pass pass
@abc.abstractmethod @abc.abstractmethod
async def remember(self, request: web.Request, response: web.StreamResponse, async def remember(self, request, response, identity, **kwargs):
identity: str, **kwargs: Any) -> None:
"""Remember identity. """Remember identity.
Modify response object by filling it's headers with remembered user. Modify response object by filling it's headers with remembered user.
@@ -28,7 +23,7 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def forget(self, request: web.Request, response: web.StreamResponse) -> None: async def forget(self, request, response):
""" Modify response which can be used to 'forget' the """ Modify response which can be used to 'forget' the
current identity on subsequent requests.""" current identity on subsequent requests."""
pass pass
@@ -37,8 +32,7 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta): class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def permits(self, identity: str, permission: Union[str, Enum], async def permits(self, identity, permission, context=None):
context: Any = None) -> bool:
"""Check user permissions. """Check user permissions.
Return True if the identity is allowed the permission in the Return True if the identity is allowed the permission in the
@@ -47,7 +41,7 @@ class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def authorized_userid(self, identity: str) -> Optional[str]: async def authorized_userid(self, identity):
"""Retrieve authorized user id. """Retrieve authorized user id.
Return the user_id of the user identified by the identity Return the user_id of the user identified by the identity

View File

@@ -1,23 +1,15 @@
import enum import enum
import warnings import warnings
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, Union
from aiohttp import web from aiohttp import web
from aiohttp_security.abc import AbstractAuthorizationPolicy, AbstractIdentityPolicy from aiohttp_security.abc import (AbstractIdentityPolicy,
AbstractAuthorizationPolicy)
from functools import wraps
IDENTITY_KEY = 'aiohttp_security_identity_policy' IDENTITY_KEY = 'aiohttp_security_identity_policy'
AUTZ_KEY = 'aiohttp_security_autz_policy' AUTZ_KEY = 'aiohttp_security_autz_policy'
# _AIP/_AAP are shorthand for Optional[policy] when we retrieve from request.
_AAP = Optional[AbstractAuthorizationPolicy]
_AIP = Optional[AbstractIdentityPolicy]
_Handler = TypeVar('_Handler', bound=Union[Callable[[web.Request], Any],
Callable[[object, web.Request], Any]])
async def remember(request, response, identity, **kwargs):
async def remember(request: web.Request, response: web.StreamResponse,
identity: str, **kwargs: Any) -> None:
"""Remember identity into response. """Remember identity into response.
The action is performed by identity_policy.remember() The action is performed by identity_policy.remember()
@@ -38,7 +30,7 @@ async def remember(request: web.Request, response: web.StreamResponse,
await identity_policy.remember(request, response, identity, **kwargs) await identity_policy.remember(request, response, identity, **kwargs)
async def forget(request: web.Request, response: web.StreamResponse) -> None: async def forget(request, response):
"""Forget previously remembered identity. """Forget previously remembered identity.
Usually it clears cookie or server-side storage to forget user Usually it clears cookie or server-side storage to forget user
@@ -55,9 +47,9 @@ async def forget(request: web.Request, response: web.StreamResponse) -> None:
await identity_policy.forget(request, response) await identity_policy.forget(request, response)
async def authorized_userid(request: web.Request) -> Optional[str]: async def authorized_userid(request):
identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY) identity_policy = request.config_dict.get(IDENTITY_KEY)
autz_policy: _AAP = request.config_dict.get(AUTZ_KEY) autz_policy = request.config_dict.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None: if identity_policy is None or autz_policy is None:
return None return None
identity = await identity_policy.identify(request) identity = await identity_policy.identify(request)
@@ -67,21 +59,20 @@ async def authorized_userid(request: web.Request) -> Optional[str]:
return user_id return user_id
async def permits(request: web.Request, permission: Union[str, enum.Enum], async def permits(request, permission, context=None):
context: Any = None) -> bool:
assert isinstance(permission, (str, enum.Enum)), permission assert isinstance(permission, (str, enum.Enum)), permission
assert permission assert permission
identity_policy: _AIP = request.config_dict.get(IDENTITY_KEY) identity_policy = request.config_dict.get(IDENTITY_KEY)
autz_policy: _AAP = request.config_dict.get(AUTZ_KEY) autz_policy = request.config_dict.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None: if identity_policy is None or autz_policy is None:
return True return True
identity = await identity_policy.identify(request) identity = await identity_policy.identify(request)
# non-registered user still may have some permissions # non-registered user still may has some permissions
access = await autz_policy.permits(identity, permission, context) access = await autz_policy.permits(identity, permission, context)
return access return access
async def is_anonymous(request: web.Request) -> bool: async def is_anonymous(request):
"""Check if user is anonymous. """Check if user is anonymous.
User is considered anonymous if there is not identity User is considered anonymous if there is not identity
@@ -96,7 +87,7 @@ async def is_anonymous(request: web.Request) -> bool:
return False return False
async def check_authorized(request: web.Request) -> str: async def check_authorized(request):
"""Checker that raises HTTPUnauthorized for anonymous users. """Checker that raises HTTPUnauthorized for anonymous users.
""" """
userid = await authorized_userid(request) userid = await authorized_userid(request)
@@ -105,32 +96,31 @@ async def check_authorized(request: web.Request) -> str:
return userid return userid
def login_required(fn: _Handler) -> _Handler: def login_required(fn):
"""Decorator that restrict access only for authorized users. """Decorator that restrict access only for authorized users.
User is considered authorized if authorized_userid User is considered authorized if authorized_userid
returns some value. returns some value.
""" """
@wraps(fn) @wraps(fn)
async def wrapped(*args: Union[object, web.Request]) -> Any: async def wrapped(*args, **kwargs):
request = args[-1] request = args[-1]
if not isinstance(request, web.Request): if not isinstance(request, web.BaseRequest):
msg = ("Incorrect decorator usage. " msg = ("Incorrect decorator usage. "
"Expecting `def handler(request)` " "Expecting `def handler(request)` "
"or `def handler(self, request)`.") "or `def handler(self, request)`.")
raise RuntimeError(msg) raise RuntimeError(msg)
await check_authorized(request) await check_authorized(request)
return await fn(*args) # type: ignore[arg-type] return await fn(*args, **kwargs)
warnings.warn("login_required decorator is deprecated, " warnings.warn("login_required decorator is deprecated, "
"use check_authorized instead", "use check_authorized instead",
DeprecationWarning) DeprecationWarning)
return wrapped # type: ignore[return-value] return wrapped
async def check_permission(request: web.Request, permission: Union[str, enum.Enum], async def check_permission(request, permission, context=None):
context: Any = None) -> None:
"""Checker that passes only to authoraised users with given permission. """Checker that passes only to authoraised users with given permission.
If user is not authorized - raises HTTPUnauthorized, If user is not authorized - raises HTTPUnauthorized,
@@ -144,7 +134,10 @@ async def check_permission(request: web.Request, permission: Union[str, enum.Enu
raise web.HTTPForbidden() raise web.HTTPForbidden()
def has_permission(permission: Union[str, enum.Enum], context: Any = None): # type: ignore def has_permission(
permission,
context=None,
):
"""Decorator that restricts access only for authorized users """Decorator that restricts access only for authorized users
with correct permissions. with correct permissions.
@@ -152,11 +145,11 @@ def has_permission(permission: Union[str, enum.Enum], context: Any = None): # t
if user is authorized and does not have permission - if user is authorized and does not have permission -
raises HTTPForbidden. raises HTTPForbidden.
""" """
def wrapper(fn): # type: ignore def wrapper(fn):
@wraps(fn) @wraps(fn)
async def wrapped(*args, **kwargs): # type: ignore async def wrapped(*args, **kwargs):
request = args[-1] request = args[-1]
if not isinstance(request, web.Request): if not isinstance(request, web.BaseRequest):
msg = ("Incorrect decorator usage. " msg = ("Incorrect decorator usage. "
"Expecting `def handler(request)` " "Expecting `def handler(request)` "
"or `def handler(self, request)`.") "or `def handler(self, request)`.")
@@ -173,8 +166,7 @@ def has_permission(permission: Union[str, enum.Enum], context: Any = None): # t
return wrapper return wrapper
def setup(app: web.Application, identity_policy: AbstractIdentityPolicy, def setup(app, identity_policy, autz_policy):
autz_policy: AbstractAuthorizationPolicy) -> None:
assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy assert isinstance(identity_policy, AbstractIdentityPolicy), identity_policy
assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy assert isinstance(autz_policy, AbstractAuthorizationPolicy), autz_policy

View File

@@ -5,32 +5,28 @@ more handy.
""" """
from aiohttp import web
from typing import Any, NewType, Optional, Union, cast
from .abc import AbstractIdentityPolicy from .abc import AbstractIdentityPolicy
_Sentinel = NewType('_Sentinel', object)
sentinel = _Sentinel(object()) sentinel = object()
class CookiesIdentityPolicy(AbstractIdentityPolicy): class CookiesIdentityPolicy(AbstractIdentityPolicy):
def __init__(self) -> None: def __init__(self):
self._cookie_name = 'AIOHTTP_SECURITY' self._cookie_name = 'AIOHTTP_SECURITY'
self._max_age = 30 * 24 * 3600 self._max_age = 30 * 24 * 3600
async def identify(self, request: web.Request) -> Optional[str]: async def identify(self, request):
return request.cookies.get(self._cookie_name) identity = request.cookies.get(self._cookie_name)
return identity
async def remember(self, request: web.Request, response: web.StreamResponse, async def remember(self, request, response, identity, max_age=sentinel,
identity: str, max_age: Union[_Sentinel, Optional[int]] = sentinel, **kwargs):
**kwargs: Any) -> None:
if max_age is sentinel: if max_age is sentinel:
max_age = self._max_age max_age = self._max_age
max_age = cast(Optional[int], max_age)
response.set_cookie(self._cookie_name, identity, response.set_cookie(self._cookie_name, identity,
max_age=max_age, **kwargs) max_age=max_age, **kwargs)
async def forget(self, request: web.Request, response: web.StreamResponse) -> None: async def forget(self, request, response):
response.del_cookie(self._cookie_name) response.del_cookie(self._cookie_name)

View File

@@ -2,17 +2,12 @@
""" """
from typing import Optional
from aiohttp import web
from .abc import AbstractIdentityPolicy from .abc import AbstractIdentityPolicy
try: try:
import jwt import jwt
HAS_JWT = True
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
HAS_JWT = False jwt = None
AUTH_HEADER_NAME = 'Authorization' AUTH_HEADER_NAME = 'Authorization'
@@ -20,21 +15,21 @@ AUTH_SCHEME = 'Bearer '
class JWTIdentityPolicy(AbstractIdentityPolicy): class JWTIdentityPolicy(AbstractIdentityPolicy):
def __init__(self, secret: str, algorithm: str = 'HS256'): def __init__(self, secret, algorithm='HS256'):
if not HAS_JWT: if jwt is None:
raise RuntimeError('Please install `PyJWT`') raise RuntimeError('Please install `PyJWT`')
self.secret = secret self.secret = secret
self.algorithm = algorithm self.algorithm = algorithm
async def identify(self, request: web.Request) -> Optional[str]: async def identify(self, request):
header_identity = request.headers.get(AUTH_HEADER_NAME) header_identity = request.headers.get(AUTH_HEADER_NAME)
if header_identity is None: if header_identity is None:
return None return
if not header_identity.startswith(AUTH_SCHEME): if not header_identity.startswith(AUTH_SCHEME):
raise ValueError('Invalid authorization scheme. ' + raise ValueError('Invalid authorization scheme. ' +
'Should be `{}<token>`'.format(AUTH_SCHEME)) 'Should be `Bearer <token>`')
token = header_identity.split(' ')[1].strip() token = header_identity.split(' ')[1].strip()
@@ -43,9 +38,8 @@ class JWTIdentityPolicy(AbstractIdentityPolicy):
algorithms=[self.algorithm]) algorithms=[self.algorithm])
return identity return identity
async def remember(self, request: web.Request, response: web.StreamResponse, async def remember(self, *args, **kwargs): # pragma: no cover
identity: str, **kwargs: None) -> None:
pass pass
async def forget(self, request: web.Request, response: web.StreamResponse) -> None: async def forget(self, request, response): # pragma: no cover
pass pass

View File

@@ -4,7 +4,6 @@ aiohttp_session.setup() should be called on application initialization
to configure aiohttp_session properly. to configure aiohttp_session properly.
""" """
from aiohttp import web
try: try:
from aiohttp_session import get_session from aiohttp_session import get_session
HAS_AIOHTTP_SESSION = True HAS_AIOHTTP_SESSION = True
@@ -16,22 +15,21 @@ from .abc import AbstractIdentityPolicy
class SessionIdentityPolicy(AbstractIdentityPolicy): class SessionIdentityPolicy(AbstractIdentityPolicy):
def __init__(self, session_key: str = 'AIOHTTP_SECURITY'): def __init__(self, session_key='AIOHTTP_SECURITY'):
self._session_key = session_key self._session_key = session_key
if not HAS_AIOHTTP_SESSION: # pragma: no cover if not HAS_AIOHTTP_SESSION: # pragma: no cover
raise ImportError( raise ImportError(
'SessionIdentityPolicy requires `aiohttp_session`') 'SessionIdentityPolicy requires `aiohttp_session`')
async def identify(self, request: web.Request) -> str: async def identify(self, request):
session = await get_session(request) session = await get_session(request)
return session.get(self._session_key) return session.get(self._session_key)
async def remember(self, request: web.Request, response: web.StreamResponse, async def remember(self, request, response, identity, **kwargs):
identity: str, **kwargs: None) -> None:
session = await get_session(request) session = await get_session(request)
session[self._session_key] = identity session[self._session_key] = identity
async def forget(self, request: web.Request, response: web.StreamResponse) -> None: async def forget(self, request, response):
session = await get_session(request) session = await get_session(request)
session.pop(self._session_key, None) session.pop(self._session_key, None)

View File

View File

@@ -1,7 +1,5 @@
from enum import Enum
from typing import Any, Optional, Union
import sqlalchemy as sa import sqlalchemy as sa
from aiohttp_security.abc import AbstractAuthorizationPolicy from aiohttp_security.abc import AbstractAuthorizationPolicy
from passlib.hash import sha256_crypt from passlib.hash import sha256_crypt
@@ -9,13 +7,13 @@ from . import db
class DBAuthorizationPolicy(AbstractAuthorizationPolicy): class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
def __init__(self, dbengine: Any): def __init__(self, dbengine):
self.dbengine = dbengine self.dbengine = dbengine
async def authorized_userid(self, identity: str) -> Optional[str]: async def authorized_userid(self, identity):
async with self.dbengine.acquire() as conn: async with self.dbengine.acquire() as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call] sa.not_(db.users.c.disabled))
query = db.users.count().where(where) query = db.users.count().where(where)
ret = await conn.scalar(query) ret = await conn.scalar(query)
if ret: if ret:
@@ -23,11 +21,13 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
else: else:
return None return None
async def permits(self, identity: str, permission: Union[str, Enum], async def permits(self, identity, permission, context=None):
context: None = None) -> bool: if identity is None:
return False
async with self.dbengine.acquire() as conn: async with self.dbengine.acquire() as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call] sa.not_(db.users.c.disabled))
query = db.users.select().where(where) query = db.users.select().where(where)
ret = await conn.execute(query) ret = await conn.execute(query)
user = await ret.fetchone() user = await ret.fetchone()
@@ -49,14 +49,14 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
return False return False
async def check_credentials(db_engine: Any, username: str, password: str) -> bool: async def check_credentials(db_engine, username, password):
async with db_engine.acquire() as conn: async with db_engine.acquire() as conn:
where = sa.and_(db.users.c.login == username, where = sa.and_(db.users.c.login == username,
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call] sa.not_(db.users.c.disabled))
query = db.users.select().where(where) query = db.users.select().where(where)
ret = await conn.execute(query) ret = await conn.execute(query)
user = await ret.fetchone() user = await ret.fetchone()
if user is not None: if user is not None:
hashed = user[2] hashed = user[2]
return sha256_crypt.verify(password, hashed) # type: ignore[no-any-return] return sha256_crypt.verify(password, hashed)
return False return False

View File

@@ -1,5 +1,4 @@
from textwrap import dedent from textwrap import dedent
from typing import NoReturn
from aiohttp import web from aiohttp import web
@@ -28,7 +27,7 @@ class Web(object):
</body> </body>
""") """)
async def index(self, request: web.Request) -> web.Response: async def index(self, request):
username = await authorized_userid(request) username = await authorized_userid(request)
if username: if username:
template = self.index_template.format( template = self.index_template.format(
@@ -38,41 +37,37 @@ class Web(object):
response = web.Response(body=template.encode()) response = web.Response(body=template.encode())
return response return response
async def login(self, request: web.Request) -> NoReturn: async def login(self, request):
invalid_resp = web.HTTPUnauthorized(body=b'Invalid username/password combination') response = web.HTTPFound('/')
form = await request.post() form = await request.post()
login = form.get('login') login = form.get('login')
password = form.get('password') password = form.get('password')
db_engine = request.app['db_engine'] db_engine = request.app.db_engine
if not (isinstance(login, str) and isinstance(password, str)):
raise invalid_resp
if await check_credentials(db_engine, login, password): if await check_credentials(db_engine, login, password):
response = web.HTTPFound('/')
await remember(request, response, login) await remember(request, response, login)
raise response raise response
raise invalid_resp raise web.HTTPUnauthorized(
body=b'Invalid username/password combination')
async def logout(self, request: web.Request) -> web.Response: async def logout(self, request):
await check_authorized(request) await check_authorized(request)
response = web.Response(body=b'You have been logged out') response = web.Response(body=b'You have been logged out')
await forget(request, response) await forget(request, response)
return response return response
async def internal_page(self, request: web.Request) -> web.Response: async def internal_page(self, request):
await check_permission(request, 'public') await check_permission(request, 'public')
response = web.Response( response = web.Response(
body=b'This page is visible for all registered users') body=b'This page is visible for all registered users')
return response return response
async def protected_page(self, request: web.Request) -> web.Response: async def protected_page(self, request):
await check_permission(request, 'protected') await check_permission(request, 'protected')
response = web.Response(body=b'You are on protected page') response = web.Response(body=b'You are on protected page')
return response return response
def configure(self, app: web.Application) -> None: def configure(self, app):
router = app.router router = app.router
router.add_route('GET', '/', self.index, name='index') router.add_route('GET', '/', self.index, name='index')
router.add_route('POST', '/login', self.login, name='login') router.add_route('POST', '/login', self.login, name='login')

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
from typing import Any, Tuple
from aiohttp import web from aiohttp import web
from aiohttp_session import setup as setup_session from aiohttp_session import setup as setup_session
@@ -14,14 +13,14 @@ from demo.database_auth.db_auth import DBAuthorizationPolicy
from demo.database_auth.handlers import Web from demo.database_auth.handlers import Web
async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]: async def init(loop):
redis_pool = await create_pool(('localhost', 6379)) redis_pool = await create_pool(('localhost', 6379))
db_engine = await create_engine(user='aiohttp_security', db_engine = await create_engine(user='aiohttp_security',
password='aiohttp_security', password='aiohttp_security',
database='aiohttp_security', database='aiohttp_security',
host='127.0.0.1') host='127.0.0.1')
app = web.Application() app = web.Application()
app['db_engine'] = db_engine app.db_engine = db_engine
setup_session(app, RedisStorage(redis_pool)) setup_session(app, RedisStorage(redis_pool))
setup_security(app, setup_security(app,
SessionIdentityPolicy(), SessionIdentityPolicy(),
@@ -36,7 +35,7 @@ async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]:
return srv, app, handler return srv, app, handler
async def finalize(srv: Any, app: Any, handler: Any) -> None: async def finalize(srv, app, handler):
sock = srv.sockets[0] sock = srv.sockets[0]
app.loop.remove_reader(sock.fileno()) app.loop.remove_reader(sock.fileno())
sock.close() sock.close()
@@ -47,7 +46,7 @@ async def finalize(srv: Any, app: Any, handler: Any) -> None:
await app.finish() await app.finish()
def main() -> None: def main():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
srv, app, handler = loop.run_until_complete(init(loop)) srv, app, handler = loop.run_until_complete(init(loop))
try: try:

View File

@@ -1,25 +1,20 @@
from enum import Enum
from typing import Dict, Optional, Union
from aiohttp_security.abc import AbstractAuthorizationPolicy from aiohttp_security.abc import AbstractAuthorizationPolicy
from .users import User
class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy): class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
def __init__(self, user_map: Dict[str, User]): def __init__(self, user_map):
super().__init__() super().__init__()
self.user_map = user_map self.user_map = user_map
async def authorized_userid(self, identity: str) -> Optional[str]: async def authorized_userid(self, identity):
"""Retrieve authorized user id. """Retrieve authorized user id.
Return the user_id of the user identified by the identity Return the user_id of the user identified by the identity
or 'None' if no user exists related to the identity. or 'None' if no user exists related to the identity.
""" """
return identity if identity in self.user_map else None if identity in self.user_map:
return identity
async def permits(self, identity: str, permission: Union[str, Enum], async def permits(self, identity, permission, context=None):
context: None = None) -> bool:
"""Check user permissions. """Check user permissions.
Return True if the identity is allowed the permission in the Return True if the identity is allowed the permission in the
current context, else return False. current context, else return False.
@@ -31,7 +26,7 @@ class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
return permission in user.permissions return permission in user.permissions
async def check_credentials(user_map: Dict[str, User], username: str, password: str) -> bool: async def check_credentials(user_map, username, password):
user = user_map.get(username) user = user_map.get(username)
if not user: if not user:
return False return False

View File

@@ -1,5 +1,4 @@
from textwrap import dedent from textwrap import dedent
from typing import Dict, NoReturn
from aiohttp import web from aiohttp import web
@@ -9,7 +8,6 @@ from aiohttp_security import (
) )
from .authz import check_credentials from .authz import check_credentials
from .users import User
index_template = dedent(""" index_template = dedent("""
@@ -29,7 +27,7 @@ index_template = dedent("""
""") """)
async def index(request: web.Request) -> web.Response: async def index(request):
username = await authorized_userid(request) username = await authorized_userid(request)
if username: if username:
template = index_template.format( template = index_template.format(
@@ -42,26 +40,22 @@ async def index(request: web.Request) -> web.Response:
) )
async def login(request: web.Request) -> NoReturn: async def login(request):
user_map: Dict[str, User] = request.app['user_map'] response = web.HTTPFound('/')
invalid_response = web.HTTPUnauthorized(body='Invalid username / password combination')
form = await request.post() form = await request.post()
username = form.get('username') username = form.get('username')
password = form.get('password') password = form.get('password')
if not (isinstance(username, str) and isinstance(password, str)): verified = await check_credentials(
raise invalid_response request.app.user_map, username, password)
verified = await check_credentials(user_map, username, password)
if verified: if verified:
response = web.HTTPFound('/')
await remember(request, response, username) await remember(request, response, username)
raise response return response
raise invalid_response return web.HTTPUnauthorized(body='Invalid username / password combination')
async def logout(request: web.Request) -> web.Response: async def logout(request):
await check_authorized(request) await check_authorized(request)
response = web.Response( response = web.Response(
text='You have been logged out', text='You have been logged out',
@@ -71,7 +65,7 @@ async def logout(request: web.Request) -> web.Response:
return response return response
async def internal_page(request: web.Request) -> web.Response: async def internal_page(request):
await check_permission(request, 'public') await check_permission(request, 'public')
response = web.Response( response = web.Response(
text='This page is visible for all registered users', text='This page is visible for all registered users',
@@ -80,7 +74,7 @@ async def internal_page(request: web.Request) -> web.Response:
return response return response
async def protected_page(request: web.Request) -> web.Response: async def protected_page(request):
await check_permission(request, 'protected') await check_permission(request, 'protected')
response = web.Response( response = web.Response(
text='You are on protected page', text='You are on protected page',
@@ -89,7 +83,7 @@ async def protected_page(request: web.Request) -> web.Response:
return response return response
def configure_handlers(app: web.Application) -> None: def configure_handlers(app):
router = app.router router = app.router
router.add_get('/', index, name='index') router.add_get('/', index, name='index')
router.add_post('/login', login, name='login') router.add_post('/login', login, name='login')

View File

@@ -11,9 +11,9 @@ from demo.dictionary_auth.handlers import configure_handlers
from demo.dictionary_auth.users import user_map from demo.dictionary_auth.users import user_map
def make_app() -> web.Application: def make_app():
app = web.Application() app = web.Application()
app['user_map'] = user_map app.user_map = user_map
configure_handlers(app) configure_handlers(app)
# secret_key must be 32 url-safe base64-encoded bytes # secret_key must be 32 url-safe base64-encoded bytes

View File

@@ -1,11 +1,6 @@
from typing import NamedTuple, Tuple from collections import namedtuple
class User(NamedTuple):
username: str
password: str
permissions: Tuple[str, ...]
User = namedtuple('User', ['username', 'password', 'permissions'])
user_map = { user_map = {
user.username: user for user in [ user.username: user for user in [

View File

@@ -1,6 +1,3 @@
from enum import Enum
from typing import NoReturn, Optional, Union
from aiohttp import web from aiohttp import web
from aiohttp_session import SimpleCookieStorage, session_middleware from aiohttp_session import SimpleCookieStorage, session_middleware
from aiohttp_security import check_permission, \ from aiohttp_security import check_permission, \
@@ -14,15 +11,15 @@ from aiohttp_security.abc import AbstractAuthorizationPolicy
# For more complicated authorization policies see examples # For more complicated authorization policies see examples
# in the 'demo' directory. # in the 'demo' directory.
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy): class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
async def authorized_userid(self, identity: str) -> Optional[str]: async def authorized_userid(self, identity):
"""Retrieve authorized user id. """Retrieve authorized user id.
Return the user_id of the user identified by the identity Return the user_id of the user identified by the identity
or 'None' if no user exists related to the identity. or 'None' if no user exists related to the identity.
""" """
return identity if identity == 'jack' else None if identity == 'jack':
return identity
async def permits(self, identity: str, permission: Union[str, Enum], async def permits(self, identity, permission, context=None):
context: None = None) -> bool:
"""Check user permissions. """Check user permissions.
Return True if the identity is allowed the permission Return True if the identity is allowed the permission
in the current context, else return False. in the current context, else return False.
@@ -30,7 +27,7 @@ class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
return identity == 'jack' and permission in ('listen',) return identity == 'jack' and permission in ('listen',)
async def handler_root(request: web.Request) -> web.Response: async def handler_root(request):
is_logged = not await is_anonymous(request) is_logged = not await is_anonymous(request)
return web.Response(text='''<html><head></head><body> return web.Response(text='''<html><head></head><body>
Hello, I'm Jack, I'm {logged} logged in.<br /><br /> Hello, I'm Jack, I'm {logged} logged in.<br /><br />
@@ -45,29 +42,29 @@ async def handler_root(request: web.Request) -> web.Response:
), content_type='text/html') ), content_type='text/html')
async def handler_login_jack(request: web.Request) -> NoReturn: async def handler_login_jack(request):
redirect_response = web.HTTPFound('/') redirect_response = web.HTTPFound('/')
await remember(request, redirect_response, 'jack') await remember(request, redirect_response, 'jack')
raise redirect_response raise redirect_response
async def handler_logout(request: web.Request) -> NoReturn: async def handler_logout(request):
redirect_response = web.HTTPFound('/') redirect_response = web.HTTPFound('/')
await forget(request, redirect_response) await forget(request, redirect_response)
raise redirect_response raise redirect_response
async def handler_listen(request: web.Request) -> web.Response: async def handler_listen(request):
await check_permission(request, 'listen') await check_permission(request, 'listen')
return web.Response(body="I can listen!") return web.Response(body="I can listen!")
async def handler_speak(request: web.Request) -> web.Response: async def handler_speak(request):
await check_permission(request, 'speak') await check_permission(request, 'speak')
return web.Response(body="I can speak!") return web.Response(body="I can speak!")
async def make_app() -> web.Application: async def make_app():
# #
# WARNING!!! # WARNING!!!
# Never use SimpleCookieStorage on production!!! # Never use SimpleCookieStorage on production!!!

View File

@@ -13,6 +13,6 @@ aioredis==1.3.1
hiredis==1.1.0 hiredis==1.1.0
passlib==1.7.4 passlib==1.7.4
cryptography==3.3.1 cryptography==3.3.1
aiohttp==3.7.3 aiohttp<3.7
pytest-aiohttp==0.3.0 pytest-aiohttp==0.3.0
pyjwt==1.7.1 pyjwt==1.7.1

View File

@@ -2,3 +2,4 @@
testpaths = tests testpaths = tests
filterwarnings= filterwarnings=
error error
ignore:The loop argument:DeprecationWarning

View File

View File

@@ -67,7 +67,7 @@ async def test_identify_broken_scheme(loop, make_token, aiohttp_client):
try: try:
await policy.identify(request) await policy.identify(request)
except ValueError as exc: except ValueError as exc:
raise web.HTTPBadRequest(reason=str(exc)) raise web.HTTPBadRequest(reason=exc)
return web.Response() return web.Response()