Add type annotations.

This commit is contained in:
Sam Bull
2020-12-18 17:58:38 +00:00
parent 2247eb72f9
commit b3b5367460
20 changed files with 200 additions and 112 deletions

View File

View File

@@ -1,20 +1,25 @@
from enum import Enum
from typing import Dict, Optional, Union
from aiohttp_security.abc import AbstractAuthorizationPolicy
from .users import User
class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
def __init__(self, user_map):
def __init__(self, user_map: Dict[str, User]):
super().__init__()
self.user_map = user_map
async def authorized_userid(self, identity):
async def authorized_userid(self, identity: str) -> Optional[str]:
"""Retrieve authorized user id.
Return the user_id of the user identified by the identity
or 'None' if no user exists related to the identity.
"""
if identity in self.user_map:
return identity
return identity if identity in self.user_map else None
async def permits(self, identity, permission, context=None):
async def permits(self, identity: str, permission: Union[str, Enum],
context: None = None) -> bool:
"""Check user permissions.
Return True if the identity is allowed the permission in the
current context, else return False.
@@ -26,7 +31,7 @@ class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
return permission in user.permissions
async def check_credentials(user_map, username, password):
async def check_credentials(user_map: Dict[str, User], username: str, password: str) -> bool:
user = user_map.get(username)
if not user:
return False

View File

@@ -1,4 +1,5 @@
from textwrap import dedent
from typing import Dict, NoReturn
from aiohttp import web
@@ -8,6 +9,7 @@ from aiohttp_security import (
)
from .authz import check_credentials
from .users import User
index_template = dedent("""
@@ -27,7 +29,7 @@ index_template = dedent("""
""")
async def index(request):
async def index(request: web.Request) -> web.Response:
username = await authorized_userid(request)
if username:
template = index_template.format(
@@ -40,22 +42,26 @@ async def index(request):
)
async def login(request):
response = web.HTTPFound('/')
async def login(request: web.Request) -> NoReturn:
user_map: Dict[str, User] = request.app['user_map']
invalid_response = web.HTTPUnauthorized(body='Invalid username / password combination')
form = await request.post()
username = form.get('username')
password = form.get('password')
verified = await check_credentials(
request.app.user_map, username, password)
if not (isinstance(username, str) and isinstance(password, str)):
raise invalid_response
verified = await check_credentials(user_map, username, password)
if verified:
response = web.HTTPFound('/')
await remember(request, response, username)
return response
raise response
return web.HTTPUnauthorized(body='Invalid username / password combination')
raise invalid_response
async def logout(request):
async def logout(request: web.Request) -> web.Response:
await check_authorized(request)
response = web.Response(
text='You have been logged out',
@@ -65,7 +71,7 @@ async def logout(request):
return response
async def internal_page(request):
async def internal_page(request: web.Request) -> web.Response:
await check_permission(request, 'public')
response = web.Response(
text='This page is visible for all registered users',
@@ -74,7 +80,7 @@ async def internal_page(request):
return response
async def protected_page(request):
async def protected_page(request: web.Request) -> web.Response:
await check_permission(request, 'protected')
response = web.Response(
text='You are on protected page',
@@ -83,7 +89,7 @@ async def protected_page(request):
return response
def configure_handlers(app):
def configure_handlers(app: web.Application) -> None:
router = app.router
router.add_get('/', index, name='index')
router.add_post('/login', login, name='login')

View File

@@ -11,9 +11,9 @@ from demo.dictionary_auth.handlers import configure_handlers
from demo.dictionary_auth.users import user_map
def make_app():
def make_app() -> web.Application:
app = web.Application()
app.user_map = user_map
app['user_map'] = user_map
configure_handlers(app)
# secret_key must be 32 url-safe base64-encoded bytes

View File

@@ -1,6 +1,11 @@
from collections import namedtuple
from typing import NamedTuple, Tuple
class User(NamedTuple):
username: str
password: str
permissions: Tuple[str, ...]
User = namedtuple('User', ['username', 'password', 'permissions'])
user_map = {
user.username: user for user in [