Add type annotations.
This commit is contained in:
0
demo/__init__.py
Normal file
0
demo/__init__.py
Normal file
0
demo/database_auth/__init__.py
Normal file
0
demo/database_auth/__init__.py
Normal file
@@ -1,5 +1,7 @@
|
||||
import sqlalchemy as sa
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||
from passlib.hash import sha256_crypt
|
||||
|
||||
@@ -7,13 +9,13 @@ from . import db
|
||||
|
||||
|
||||
class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
def __init__(self, dbengine):
|
||||
def __init__(self, dbengine: Any):
|
||||
self.dbengine = dbengine
|
||||
|
||||
async def authorized_userid(self, identity):
|
||||
async def authorized_userid(self, identity: str) -> Optional[str]:
|
||||
async with self.dbengine.acquire() as conn:
|
||||
where = sa.and_(db.users.c.login == identity,
|
||||
sa.not_(db.users.c.disabled))
|
||||
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
|
||||
query = db.users.count().where(where)
|
||||
ret = await conn.scalar(query)
|
||||
if ret:
|
||||
@@ -21,13 +23,11 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def permits(self, identity, permission, context=None):
|
||||
if identity is None:
|
||||
return False
|
||||
|
||||
async def permits(self, identity: str, permission: Union[str, Enum],
|
||||
context: None = None) -> bool:
|
||||
async with self.dbengine.acquire() as conn:
|
||||
where = sa.and_(db.users.c.login == identity,
|
||||
sa.not_(db.users.c.disabled))
|
||||
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
|
||||
query = db.users.select().where(where)
|
||||
ret = await conn.execute(query)
|
||||
user = await ret.fetchone()
|
||||
@@ -49,14 +49,14 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
return False
|
||||
|
||||
|
||||
async def check_credentials(db_engine, username, password):
|
||||
async def check_credentials(db_engine: Any, username: str, password: str) -> bool:
|
||||
async with db_engine.acquire() as conn:
|
||||
where = sa.and_(db.users.c.login == username,
|
||||
sa.not_(db.users.c.disabled))
|
||||
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
|
||||
query = db.users.select().where(where)
|
||||
ret = await conn.execute(query)
|
||||
user = await ret.fetchone()
|
||||
if user is not None:
|
||||
hashed = user[2]
|
||||
return sha256_crypt.verify(password, hashed)
|
||||
return sha256_crypt.verify(password, hashed) # type: ignore[no-any-return]
|
||||
return False
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from textwrap import dedent
|
||||
from typing import NoReturn
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
@@ -27,7 +28,7 @@ class Web(object):
|
||||
</body>
|
||||
""")
|
||||
|
||||
async def index(self, request):
|
||||
async def index(self, request: web.Request) -> web.Response:
|
||||
username = await authorized_userid(request)
|
||||
if username:
|
||||
template = self.index_template.format(
|
||||
@@ -37,37 +38,41 @@ class Web(object):
|
||||
response = web.Response(body=template.encode())
|
||||
return response
|
||||
|
||||
async def login(self, request):
|
||||
response = web.HTTPFound('/')
|
||||
async def login(self, request: web.Request) -> NoReturn:
|
||||
invalid_resp = web.HTTPUnauthorized(body=b'Invalid username/password combination')
|
||||
form = await request.post()
|
||||
login = form.get('login')
|
||||
password = form.get('password')
|
||||
db_engine = request.app.db_engine
|
||||
db_engine = request.app['db_engine']
|
||||
|
||||
if not (isinstance(login, str) and isinstance(password, str)):
|
||||
raise invalid_resp
|
||||
|
||||
if await check_credentials(db_engine, login, password):
|
||||
response = web.HTTPFound('/')
|
||||
await remember(request, response, login)
|
||||
raise response
|
||||
|
||||
raise web.HTTPUnauthorized(
|
||||
body=b'Invalid username/password combination')
|
||||
raise invalid_resp
|
||||
|
||||
async def logout(self, request):
|
||||
async def logout(self, request: web.Request) -> web.Response:
|
||||
await check_authorized(request)
|
||||
response = web.Response(body=b'You have been logged out')
|
||||
await forget(request, response)
|
||||
return response
|
||||
|
||||
async def internal_page(self, request):
|
||||
async def internal_page(self, request: web.Request) -> web.Response:
|
||||
await check_permission(request, 'public')
|
||||
response = web.Response(
|
||||
body=b'This page is visible for all registered users')
|
||||
return response
|
||||
|
||||
async def protected_page(self, request):
|
||||
async def protected_page(self, request: web.Request) -> web.Response:
|
||||
await check_permission(request, 'protected')
|
||||
response = web.Response(body=b'You are on protected page')
|
||||
return response
|
||||
|
||||
def configure(self, app):
|
||||
def configure(self, app: web.Application) -> None:
|
||||
router = app.router
|
||||
router.add_route('GET', '/', self.index, name='index')
|
||||
router.add_route('POST', '/login', self.login, name='login')
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from typing import Any, Tuple
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp_session import setup as setup_session
|
||||
@@ -13,14 +14,14 @@ from demo.database_auth.db_auth import DBAuthorizationPolicy
|
||||
from demo.database_auth.handlers import Web
|
||||
|
||||
|
||||
async def init(loop):
|
||||
async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]:
|
||||
redis_pool = await create_pool(('localhost', 6379))
|
||||
db_engine = await create_engine(user='aiohttp_security',
|
||||
password='aiohttp_security',
|
||||
database='aiohttp_security',
|
||||
host='127.0.0.1')
|
||||
app = web.Application()
|
||||
app.db_engine = db_engine
|
||||
app['db_engine'] = db_engine
|
||||
setup_session(app, RedisStorage(redis_pool))
|
||||
setup_security(app,
|
||||
SessionIdentityPolicy(),
|
||||
@@ -35,7 +36,7 @@ async def init(loop):
|
||||
return srv, app, handler
|
||||
|
||||
|
||||
async def finalize(srv, app, handler):
|
||||
async def finalize(srv: Any, app: Any, handler: Any) -> None:
|
||||
sock = srv.sockets[0]
|
||||
app.loop.remove_reader(sock.fileno())
|
||||
sock.close()
|
||||
@@ -46,7 +47,7 @@ async def finalize(srv, app, handler):
|
||||
await app.finish()
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
srv, app, handler = loop.run_until_complete(init(loop))
|
||||
try:
|
||||
|
0
demo/dictionary_auth/__init__.py
Normal file
0
demo/dictionary_auth/__init__.py
Normal file
@@ -1,20 +1,25 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||
|
||||
from .users import User
|
||||
|
||||
|
||||
class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
def __init__(self, user_map):
|
||||
def __init__(self, user_map: Dict[str, User]):
|
||||
super().__init__()
|
||||
self.user_map = user_map
|
||||
|
||||
async def authorized_userid(self, identity):
|
||||
async def authorized_userid(self, identity: str) -> Optional[str]:
|
||||
"""Retrieve authorized user id.
|
||||
Return the user_id of the user identified by the identity
|
||||
or 'None' if no user exists related to the identity.
|
||||
"""
|
||||
if identity in self.user_map:
|
||||
return identity
|
||||
return identity if identity in self.user_map else None
|
||||
|
||||
async def permits(self, identity, permission, context=None):
|
||||
async def permits(self, identity: str, permission: Union[str, Enum],
|
||||
context: None = None) -> bool:
|
||||
"""Check user permissions.
|
||||
Return True if the identity is allowed the permission in the
|
||||
current context, else return False.
|
||||
@@ -26,7 +31,7 @@ class DictionaryAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
return permission in user.permissions
|
||||
|
||||
|
||||
async def check_credentials(user_map, username, password):
|
||||
async def check_credentials(user_map: Dict[str, User], username: str, password: str) -> bool:
|
||||
user = user_map.get(username)
|
||||
if not user:
|
||||
return False
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from textwrap import dedent
|
||||
from typing import Dict, NoReturn
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
@@ -8,6 +9,7 @@ from aiohttp_security import (
|
||||
)
|
||||
|
||||
from .authz import check_credentials
|
||||
from .users import User
|
||||
|
||||
|
||||
index_template = dedent("""
|
||||
@@ -27,7 +29,7 @@ index_template = dedent("""
|
||||
""")
|
||||
|
||||
|
||||
async def index(request):
|
||||
async def index(request: web.Request) -> web.Response:
|
||||
username = await authorized_userid(request)
|
||||
if username:
|
||||
template = index_template.format(
|
||||
@@ -40,22 +42,26 @@ async def index(request):
|
||||
)
|
||||
|
||||
|
||||
async def login(request):
|
||||
response = web.HTTPFound('/')
|
||||
async def login(request: web.Request) -> NoReturn:
|
||||
user_map: Dict[str, User] = request.app['user_map']
|
||||
invalid_response = web.HTTPUnauthorized(body='Invalid username / password combination')
|
||||
form = await request.post()
|
||||
username = form.get('username')
|
||||
password = form.get('password')
|
||||
|
||||
verified = await check_credentials(
|
||||
request.app.user_map, username, password)
|
||||
if not (isinstance(username, str) and isinstance(password, str)):
|
||||
raise invalid_response
|
||||
|
||||
verified = await check_credentials(user_map, username, password)
|
||||
if verified:
|
||||
response = web.HTTPFound('/')
|
||||
await remember(request, response, username)
|
||||
return response
|
||||
raise response
|
||||
|
||||
return web.HTTPUnauthorized(body='Invalid username / password combination')
|
||||
raise invalid_response
|
||||
|
||||
|
||||
async def logout(request):
|
||||
async def logout(request: web.Request) -> web.Response:
|
||||
await check_authorized(request)
|
||||
response = web.Response(
|
||||
text='You have been logged out',
|
||||
@@ -65,7 +71,7 @@ async def logout(request):
|
||||
return response
|
||||
|
||||
|
||||
async def internal_page(request):
|
||||
async def internal_page(request: web.Request) -> web.Response:
|
||||
await check_permission(request, 'public')
|
||||
response = web.Response(
|
||||
text='This page is visible for all registered users',
|
||||
@@ -74,7 +80,7 @@ async def internal_page(request):
|
||||
return response
|
||||
|
||||
|
||||
async def protected_page(request):
|
||||
async def protected_page(request: web.Request) -> web.Response:
|
||||
await check_permission(request, 'protected')
|
||||
response = web.Response(
|
||||
text='You are on protected page',
|
||||
@@ -83,7 +89,7 @@ async def protected_page(request):
|
||||
return response
|
||||
|
||||
|
||||
def configure_handlers(app):
|
||||
def configure_handlers(app: web.Application) -> None:
|
||||
router = app.router
|
||||
router.add_get('/', index, name='index')
|
||||
router.add_post('/login', login, name='login')
|
||||
|
@@ -11,9 +11,9 @@ from demo.dictionary_auth.handlers import configure_handlers
|
||||
from demo.dictionary_auth.users import user_map
|
||||
|
||||
|
||||
def make_app():
|
||||
def make_app() -> web.Application:
|
||||
app = web.Application()
|
||||
app.user_map = user_map
|
||||
app['user_map'] = user_map
|
||||
configure_handlers(app)
|
||||
|
||||
# secret_key must be 32 url-safe base64-encoded bytes
|
||||
|
@@ -1,6 +1,11 @@
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple, Tuple
|
||||
|
||||
|
||||
class User(NamedTuple):
|
||||
username: str
|
||||
password: str
|
||||
permissions: Tuple[str, ...]
|
||||
|
||||
User = namedtuple('User', ['username', 'password', 'permissions'])
|
||||
|
||||
user_map = {
|
||||
user.username: user for user in [
|
||||
|
@@ -1,3 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import NoReturn, Optional, Union
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp_session import SimpleCookieStorage, session_middleware
|
||||
from aiohttp_security import check_permission, \
|
||||
@@ -11,15 +14,15 @@ from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||
# For more complicated authorization policies see examples
|
||||
# in the 'demo' directory.
|
||||
class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
async def authorized_userid(self, identity):
|
||||
async def authorized_userid(self, identity: str) -> Optional[str]:
|
||||
"""Retrieve authorized user id.
|
||||
Return the user_id of the user identified by the identity
|
||||
or 'None' if no user exists related to the identity.
|
||||
"""
|
||||
if identity == 'jack':
|
||||
return identity
|
||||
return identity if identity == 'jack' else None
|
||||
|
||||
async def permits(self, identity, permission, context=None):
|
||||
async def permits(self, identity: str, permission: Union[str, Enum],
|
||||
context: None = None) -> bool:
|
||||
"""Check user permissions.
|
||||
Return True if the identity is allowed the permission
|
||||
in the current context, else return False.
|
||||
@@ -27,7 +30,7 @@ class SimpleJack_AuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
return identity == 'jack' and permission in ('listen',)
|
||||
|
||||
|
||||
async def handler_root(request):
|
||||
async def handler_root(request: web.Request) -> web.Response:
|
||||
is_logged = not await is_anonymous(request)
|
||||
return web.Response(text='''<html><head></head><body>
|
||||
Hello, I'm Jack, I'm {logged} logged in.<br /><br />
|
||||
@@ -42,29 +45,29 @@ async def handler_root(request):
|
||||
), content_type='text/html')
|
||||
|
||||
|
||||
async def handler_login_jack(request):
|
||||
async def handler_login_jack(request: web.Request) -> NoReturn:
|
||||
redirect_response = web.HTTPFound('/')
|
||||
await remember(request, redirect_response, 'jack')
|
||||
raise redirect_response
|
||||
|
||||
|
||||
async def handler_logout(request):
|
||||
async def handler_logout(request: web.Request) -> NoReturn:
|
||||
redirect_response = web.HTTPFound('/')
|
||||
await forget(request, redirect_response)
|
||||
raise redirect_response
|
||||
|
||||
|
||||
async def handler_listen(request):
|
||||
async def handler_listen(request: web.Request) -> web.Response:
|
||||
await check_permission(request, 'listen')
|
||||
return web.Response(body="I can listen!")
|
||||
|
||||
|
||||
async def handler_speak(request):
|
||||
async def handler_speak(request: web.Request) -> web.Response:
|
||||
await check_permission(request, 'speak')
|
||||
return web.Response(body="I can speak!")
|
||||
|
||||
|
||||
async def make_app():
|
||||
async def make_app() -> web.Application:
|
||||
#
|
||||
# WARNING!!!
|
||||
# Never use SimpleCookieStorage on production!!!
|
||||
|
Reference in New Issue
Block a user