Add type annotations.

This commit is contained in:
Sam Bull
2020-12-18 17:58:38 +00:00
parent 2247eb72f9
commit b3b5367460
20 changed files with 200 additions and 112 deletions

View File

View File

@@ -1,5 +1,7 @@
import sqlalchemy as sa
from enum import Enum
from typing import Any, Optional, Union
import sqlalchemy as sa
from aiohttp_security.abc import AbstractAuthorizationPolicy
from passlib.hash import sha256_crypt
@@ -7,13 +9,13 @@ from . import db
class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
def __init__(self, dbengine):
def __init__(self, dbengine: Any):
self.dbengine = dbengine
async def authorized_userid(self, identity):
async def authorized_userid(self, identity: str) -> Optional[str]:
async with self.dbengine.acquire() as conn:
where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled))
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
query = db.users.count().where(where)
ret = await conn.scalar(query)
if ret:
@@ -21,13 +23,11 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
else:
return None
async def permits(self, identity, permission, context=None):
if identity is None:
return False
async def permits(self, identity: str, permission: Union[str, Enum],
context: None = None) -> bool:
async with self.dbengine.acquire() as conn:
where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled))
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
query = db.users.select().where(where)
ret = await conn.execute(query)
user = await ret.fetchone()
@@ -49,14 +49,14 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
return False
async def check_credentials(db_engine, username, password):
async def check_credentials(db_engine: Any, username: str, password: str) -> bool:
async with db_engine.acquire() as conn:
where = sa.and_(db.users.c.login == username,
sa.not_(db.users.c.disabled))
sa.not_(db.users.c.disabled)) # type: ignore[no-untyped-call]
query = db.users.select().where(where)
ret = await conn.execute(query)
user = await ret.fetchone()
if user is not None:
hashed = user[2]
return sha256_crypt.verify(password, hashed)
return sha256_crypt.verify(password, hashed) # type: ignore[no-any-return]
return False

View File

@@ -1,4 +1,5 @@
from textwrap import dedent
from typing import NoReturn
from aiohttp import web
@@ -27,7 +28,7 @@ class Web(object):
</body>
""")
async def index(self, request):
async def index(self, request: web.Request) -> web.Response:
username = await authorized_userid(request)
if username:
template = self.index_template.format(
@@ -37,37 +38,41 @@ class Web(object):
response = web.Response(body=template.encode())
return response
async def login(self, request):
response = web.HTTPFound('/')
async def login(self, request: web.Request) -> NoReturn:
invalid_resp = web.HTTPUnauthorized(body=b'Invalid username/password combination')
form = await request.post()
login = form.get('login')
password = form.get('password')
db_engine = request.app.db_engine
db_engine = request.app['db_engine']
if not (isinstance(login, str) and isinstance(password, str)):
raise invalid_resp
if await check_credentials(db_engine, login, password):
response = web.HTTPFound('/')
await remember(request, response, login)
raise response
raise web.HTTPUnauthorized(
body=b'Invalid username/password combination')
raise invalid_resp
async def logout(self, request):
async def logout(self, request: web.Request) -> web.Response:
await check_authorized(request)
response = web.Response(body=b'You have been logged out')
await forget(request, response)
return response
async def internal_page(self, request):
async def internal_page(self, request: web.Request) -> web.Response:
await check_permission(request, 'public')
response = web.Response(
body=b'This page is visible for all registered users')
return response
async def protected_page(self, request):
async def protected_page(self, request: web.Request) -> web.Response:
await check_permission(request, 'protected')
response = web.Response(body=b'You are on protected page')
return response
def configure(self, app):
def configure(self, app: web.Application) -> None:
router = app.router
router.add_route('GET', '/', self.index, name='index')
router.add_route('POST', '/login', self.login, name='login')

View File

@@ -1,4 +1,5 @@
import asyncio
from typing import Any, Tuple
from aiohttp import web
from aiohttp_session import setup as setup_session
@@ -13,14 +14,14 @@ from demo.database_auth.db_auth import DBAuthorizationPolicy
from demo.database_auth.handlers import Web
async def init(loop):
async def init(loop: asyncio.AbstractEventLoop) -> Tuple[Any, ...]:
redis_pool = await create_pool(('localhost', 6379))
db_engine = await create_engine(user='aiohttp_security',
password='aiohttp_security',
database='aiohttp_security',
host='127.0.0.1')
app = web.Application()
app.db_engine = db_engine
app['db_engine'] = db_engine
setup_session(app, RedisStorage(redis_pool))
setup_security(app,
SessionIdentityPolicy(),
@@ -35,7 +36,7 @@ async def init(loop):
return srv, app, handler
async def finalize(srv, app, handler):
async def finalize(srv: Any, app: Any, handler: Any) -> None:
sock = srv.sockets[0]
app.loop.remove_reader(sock.fileno())
sock.close()
@@ -46,7 +47,7 @@ async def finalize(srv, app, handler):
await app.finish()
def main():
def main() -> None:
loop = asyncio.get_event_loop()
srv, app, handler = loop.run_until_complete(init(loop))
try: