Switch to async/await syntax
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from aiohttp_security.abc import AbstractAuthorizationPolicy
|
||||
@@ -12,29 +10,27 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
def __init__(self, dbengine):
|
||||
self.dbengine = dbengine
|
||||
|
||||
@asyncio.coroutine
|
||||
def authorized_userid(self, identity):
|
||||
with (yield from self.dbengine) as conn:
|
||||
async with self.dbengine as conn:
|
||||
where = sa.and_(db.users.c.login == identity,
|
||||
sa.not_(db.users.c.disabled))
|
||||
query = db.users.count().where(where)
|
||||
ret = yield from conn.scalar(query)
|
||||
ret = await conn.scalar(query)
|
||||
if ret:
|
||||
return identity
|
||||
else:
|
||||
return None
|
||||
|
||||
@asyncio.coroutine
|
||||
def permits(self, identity, permission, context=None):
|
||||
async def permits(self, identity, permission, context=None):
|
||||
if identity is None:
|
||||
return False
|
||||
|
||||
with (yield from self.dbengine) as conn:
|
||||
async with self.dbengine as conn:
|
||||
where = sa.and_(db.users.c.login == identity,
|
||||
sa.not_(db.users.c.disabled))
|
||||
query = db.users.select().where(where)
|
||||
ret = yield from conn.execute(query)
|
||||
user = yield from ret.fetchone()
|
||||
ret = await conn.execute(query)
|
||||
user = await ret.fetchone()
|
||||
if user is not None:
|
||||
user_id = user[0]
|
||||
is_superuser = user[3]
|
||||
@@ -43,8 +39,8 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
|
||||
where = db.permissions.c.user_id == user_id
|
||||
query = db.permissions.select().where(where)
|
||||
ret = yield from conn.execute(query)
|
||||
result = yield from ret.fetchall()
|
||||
ret = await conn.execute(query)
|
||||
result = await ret.fetchall()
|
||||
if ret is not None:
|
||||
for record in result:
|
||||
if record.perm_name == permission:
|
||||
@@ -53,14 +49,13 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
|
||||
return False
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def check_credentials(db_engine, username, password):
|
||||
with (yield from db_engine) as conn:
|
||||
async def check_credentials(db_engine, username, password):
|
||||
async with db_engine as conn:
|
||||
where = sa.and_(db.users.c.login == username,
|
||||
sa.not_(db.users.c.disabled))
|
||||
query = db.users.select().where(where)
|
||||
ret = yield from conn.execute(query)
|
||||
user = yield from ret.fetchone()
|
||||
ret = await conn.execute(query)
|
||||
user = await ret.fetchone()
|
||||
if user is not None:
|
||||
hash = user[2]
|
||||
return sha256_crypt.verify(password, hash)
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import functools
|
||||
|
||||
from aiohttp import web
|
||||
@@ -10,14 +9,13 @@ from .db_auth import check_credentials
|
||||
|
||||
def require(permission):
|
||||
def wrapper(f):
|
||||
@asyncio.coroutine
|
||||
@functools.wraps(f)
|
||||
def wrapped(self, request):
|
||||
has_perm = yield from permits(request, permission)
|
||||
async def wrapped(self, request):
|
||||
has_perm = await permits(request, permission)
|
||||
if not has_perm:
|
||||
message = 'User has no permission {}'.format(permission)
|
||||
raise web.HTTPForbidden(body=message.encode())
|
||||
return (yield from f(self, request))
|
||||
return await f(self, request)
|
||||
return wrapped
|
||||
return wrapper
|
||||
|
||||
@@ -40,9 +38,8 @@ class Web(object):
|
||||
</body>
|
||||
"""
|
||||
|
||||
@asyncio.coroutine
|
||||
def index(self, request):
|
||||
username = yield from authorized_userid(request)
|
||||
async def index(self, request):
|
||||
username = await authorized_userid(request)
|
||||
if username:
|
||||
template = self.index_template.format(
|
||||
message='Hello, {username}!'.format(username=username))
|
||||
@@ -51,37 +48,33 @@ class Web(object):
|
||||
response = web.Response(body=template.encode())
|
||||
return response
|
||||
|
||||
@asyncio.coroutine
|
||||
def login(self, request):
|
||||
async def login(self, request):
|
||||
response = web.HTTPFound('/')
|
||||
form = yield from request.post()
|
||||
form = await request.post()
|
||||
login = form.get('login')
|
||||
password = form.get('password')
|
||||
db_engine = request.app.db_engine
|
||||
if (yield from check_credentials(db_engine, login, password)):
|
||||
yield from remember(request, response, login)
|
||||
if await check_credentials(db_engine, login, password):
|
||||
await remember(request, response, login)
|
||||
return response
|
||||
|
||||
return web.HTTPUnauthorized(
|
||||
body=b'Invalid username/password combination')
|
||||
|
||||
@require('public')
|
||||
@asyncio.coroutine
|
||||
def logout(self, request):
|
||||
async def logout(self, request):
|
||||
response = web.Response(body=b'You have been logged out')
|
||||
yield from forget(request, response)
|
||||
await forget(request, response)
|
||||
return response
|
||||
|
||||
@require('public')
|
||||
@asyncio.coroutine
|
||||
def internal_page(self, request):
|
||||
async def internal_page(self, request):
|
||||
response = web.Response(
|
||||
body=b'This page is visible for all registered users')
|
||||
return response
|
||||
|
||||
@require('protected')
|
||||
@asyncio.coroutine
|
||||
def protected_page(self, request):
|
||||
async def protected_page(self, request):
|
||||
response = web.Response(body=b'You are on protected page')
|
||||
return response
|
||||
|
||||
|
@@ -13,10 +13,9 @@ from demo.db_auth import DBAuthorizationPolicy
|
||||
from demo.handlers import Web
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def init(loop):
|
||||
redis_pool = yield from create_pool(('localhost', 6379))
|
||||
db_engine = yield from create_engine(user='aiohttp_security',
|
||||
redis_pool = await create_pool(('localhost', 6379))
|
||||
db_engine = await create_engine(user='aiohttp_security',
|
||||
password='aiohttp_security',
|
||||
database='aiohttp_security',
|
||||
host='127.0.0.1')
|
||||
@@ -31,21 +30,20 @@ def init(loop):
|
||||
web_handlers.configure(app)
|
||||
|
||||
handler = app.make_handler()
|
||||
srv = yield from loop.create_server(handler, '127.0.0.1', 8080)
|
||||
srv = await loop.create_server(handler, '127.0.0.1', 8080)
|
||||
print('Server started at http://127.0.0.1:8080')
|
||||
return srv, app, handler
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def finalize(srv, app, handler):
|
||||
async def finalize(srv, app, handler):
|
||||
sock = srv.sockets[0]
|
||||
app.loop.remove_reader(sock.fileno())
|
||||
sock.close()
|
||||
|
||||
yield from handler.finish_connections(1.0)
|
||||
await handler.finish_connections(1.0)
|
||||
srv.close()
|
||||
yield from srv.wait_closed()
|
||||
yield from app.finish()
|
||||
await srv.wait_closed()
|
||||
await app.finish()
|
||||
|
||||
|
||||
def main():
|
||||
|
Reference in New Issue
Block a user