Switch to async/await syntax

This commit is contained in:
Andrew Svetlov 2017-12-13 16:51:46 +02:00
parent b9dee120c3
commit 5b2ff779c3
16 changed files with 257 additions and 393 deletions

View File

@ -1,6 +1,5 @@
language: python language: python
python: python:
- 3.4.3
- 3.5 - 3.5
- 3.6 - 3.6
- 3.7-dev - 3.7-dev

View File

@ -1,21 +1,18 @@
import abc import abc
import asyncio
# see http://plope.com/pyramid_auth_design_api_postmortem # see http://plope.com/pyramid_auth_design_api_postmortem
class AbstractIdentityPolicy(metaclass=abc.ABCMeta): class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
@asyncio.coroutine
@abc.abstractmethod @abc.abstractmethod
def identify(self, request): async def identify(self, request):
"""Return the claimed identity of the user associated request or """Return the claimed identity of the user associated request or
``None`` if no identity can be found associated with the request.""" ``None`` if no identity can be found associated with the request."""
pass pass
@asyncio.coroutine
@abc.abstractmethod @abc.abstractmethod
def remember(self, request, response, identity, **kwargs): async def remember(self, request, response, identity, **kwargs):
"""Remember identity. """Remember identity.
Modify response object by filling it's headers with remembered user. Modify response object by filling it's headers with remembered user.
@ -25,9 +22,8 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
""" """
pass pass
@asyncio.coroutine
@abc.abstractmethod @abc.abstractmethod
def forget(self, request, response): async def forget(self, request, response):
""" Modify response which can be used to 'forget' the """ Modify response which can be used to 'forget' the
current identity on subsequent requests.""" current identity on subsequent requests."""
pass pass
@ -35,9 +31,8 @@ class AbstractIdentityPolicy(metaclass=abc.ABCMeta):
class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta): class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
@asyncio.coroutine
@abc.abstractmethod @abc.abstractmethod
def permits(self, identity, permission, context=None): async def permits(self, identity, permission, context=None):
"""Check user permissions. """Check user permissions.
Return True if the identity is allowed the permission in the Return True if the identity is allowed the permission in the
@ -45,9 +40,8 @@ class AbstractAuthorizationPolicy(metaclass=abc.ABCMeta):
""" """
pass pass
@asyncio.coroutine
@abc.abstractmethod @abc.abstractmethod
def authorized_userid(self, identity): async def authorized_userid(self, identity):
"""Retrieve authorized user id. """Retrieve authorized user id.
Return the user_id of the user identified by the identity Return the user_id of the user identified by the identity

View File

@ -1,4 +1,3 @@
import asyncio
import enum import enum
from aiohttp import web from aiohttp import web
from aiohttp_security.abc import (AbstractIdentityPolicy, from aiohttp_security.abc import (AbstractIdentityPolicy,
@ -9,8 +8,7 @@ IDENTITY_KEY = 'aiohttp_security_identity_policy'
AUTZ_KEY = 'aiohttp_security_autz_policy' AUTZ_KEY = 'aiohttp_security_autz_policy'
@asyncio.coroutine async def remember(request, response, identity, **kwargs):
def remember(request, response, identity, **kwargs):
"""Remember identity into response. """Remember identity into response.
The action is performed by identity_policy.remember() The action is performed by identity_policy.remember()
@ -28,11 +26,10 @@ def remember(request, response, identity, **kwargs):
# output and rendered page we add same message to *reason* and # output and rendered page we add same message to *reason* and
# *text* arguments. # *text* arguments.
raise web.HTTPInternalServerError(reason=text, text=text) raise web.HTTPInternalServerError(reason=text, text=text)
yield from identity_policy.remember(request, response, identity, **kwargs) await identity_policy.remember(request, response, identity, **kwargs)
@asyncio.coroutine async def forget(request, response):
def forget(request, response):
"""Forget previously remembered identity. """Forget previously remembered identity.
Usually it clears cookie or server-side storage to forget user Usually it clears cookie or server-side storage to forget user
@ -46,38 +43,35 @@ def forget(request, response):
# output and rendered page we add same message to *reason* and # output and rendered page we add same message to *reason* and
# *text* arguments. # *text* arguments.
raise web.HTTPInternalServerError(reason=text, text=text) raise web.HTTPInternalServerError(reason=text, text=text)
yield from identity_policy.forget(request, response) await identity_policy.forget(request, response)
@asyncio.coroutine async def authorized_userid(request):
def authorized_userid(request):
identity_policy = request.app.get(IDENTITY_KEY) identity_policy = request.app.get(IDENTITY_KEY)
autz_policy = request.app.get(AUTZ_KEY) autz_policy = request.app.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None: if identity_policy is None or autz_policy is None:
return None return None
identity = yield from identity_policy.identify(request) identity = await identity_policy.identify(request)
if identity is None: if identity is None:
return None # non-registered user has None user_id return None # non-registered user has None user_id
user_id = yield from autz_policy.authorized_userid(identity) user_id = await autz_policy.authorized_userid(identity)
return user_id return user_id
@asyncio.coroutine async def permits(request, permission, context=None):
def permits(request, permission, context=None):
assert isinstance(permission, (str, enum.Enum)), permission assert isinstance(permission, (str, enum.Enum)), permission
assert permission assert permission
identity_policy = request.app.get(IDENTITY_KEY) identity_policy = request.app.get(IDENTITY_KEY)
autz_policy = request.app.get(AUTZ_KEY) autz_policy = request.app.get(AUTZ_KEY)
if identity_policy is None or autz_policy is None: if identity_policy is None or autz_policy is None:
return True return True
identity = yield from identity_policy.identify(request) identity = await identity_policy.identify(request)
# non-registered user still may has some permissions # non-registered user still may has some permissions
access = yield from autz_policy.permits(identity, permission, context) access = await autz_policy.permits(identity, permission, context)
return access return access
@asyncio.coroutine async def is_anonymous(request):
def is_anonymous(request):
"""Check if user is anonymous. """Check if user is anonymous.
User is considered anonymous if there is not identity User is considered anonymous if there is not identity
@ -86,7 +80,7 @@ def is_anonymous(request):
identity_policy = request.app.get(IDENTITY_KEY) identity_policy = request.app.get(IDENTITY_KEY)
if identity_policy is None: if identity_policy is None:
return True return True
identity = yield from identity_policy.identify(request) identity = await identity_policy.identify(request)
if identity is None: if identity is None:
return True return True
return False return False
@ -98,9 +92,8 @@ def login_required(fn):
User is considered authorized if authorized_userid User is considered authorized if authorized_userid
returns some value. returns some value.
""" """
@asyncio.coroutine
@wraps(fn) @wraps(fn)
def wrapped(*args, **kwargs): async def wrapped(*args, **kwargs):
request = args[-1] request = args[-1]
if not isinstance(request, web.BaseRequest): if not isinstance(request, web.BaseRequest):
msg = ("Incorrect decorator usage. " msg = ("Incorrect decorator usage. "
@ -108,11 +101,11 @@ def login_required(fn):
"or `def handler(self, request)`.") "or `def handler(self, request)`.")
raise RuntimeError(msg) raise RuntimeError(msg)
userid = yield from authorized_userid(request) userid = await authorized_userid(request)
if userid is None: if userid is None:
raise web.HTTPUnauthorized raise web.HTTPUnauthorized
ret = yield from fn(*args, **kwargs) ret = await fn(*args, **kwargs)
return ret return ret
return wrapped return wrapped
@ -130,9 +123,8 @@ def has_permission(
raises HTTPForbidden. raises HTTPForbidden.
""" """
def wrapper(fn): def wrapper(fn):
@asyncio.coroutine
@wraps(fn) @wraps(fn)
def wrapped(*args, **kwargs): async def wrapped(*args, **kwargs):
request = args[-1] request = args[-1]
if not isinstance(request, web.BaseRequest): if not isinstance(request, web.BaseRequest):
msg = ("Incorrect decorator usage. " msg = ("Incorrect decorator usage. "
@ -140,14 +132,14 @@ def has_permission(
"or `def handler(self, request)`.") "or `def handler(self, request)`.")
raise RuntimeError(msg) raise RuntimeError(msg)
userid = yield from authorized_userid(request) userid = await authorized_userid(request)
if userid is None: if userid is None:
raise web.HTTPUnauthorized raise web.HTTPUnauthorized
allowed = yield from permits(request, permission, context) allowed = await permits(request, permission, context)
if not allowed: if not allowed:
raise web.HTTPForbidden raise web.HTTPForbidden
ret = yield from fn(*args, **kwargs) ret = await fn(*args, **kwargs)
return ret return ret
return wrapped return wrapped

View File

@ -5,8 +5,6 @@ more handy.
""" """
import asyncio
from .abc import AbstractIdentityPolicy from .abc import AbstractIdentityPolicy
@ -19,19 +17,16 @@ class CookiesIdentityPolicy(AbstractIdentityPolicy):
self._cookie_name = 'AIOHTTP_SECURITY' self._cookie_name = 'AIOHTTP_SECURITY'
self._max_age = 30 * 24 * 3600 self._max_age = 30 * 24 * 3600
@asyncio.coroutine async def identify(self, request):
def identify(self, request):
identity = request.cookies.get(self._cookie_name) identity = request.cookies.get(self._cookie_name)
return identity return identity
@asyncio.coroutine async def remember(self, request, response, identity, max_age=sentinel,
def remember(self, request, response, identity, max_age=sentinel, **kwargs):
**kwargs):
if max_age is sentinel: if max_age is sentinel:
max_age = self._max_age max_age = self._max_age
response.set_cookie(self._cookie_name, identity, response.set_cookie(self._cookie_name, identity,
max_age=max_age, **kwargs) max_age=max_age, **kwargs)
@asyncio.coroutine async def forget(self, request, response):
def forget(self, request, response):
response.del_cookie(self._cookie_name) response.del_cookie(self._cookie_name)

View File

@ -4,8 +4,6 @@ aiohttp_session.setup() should be called on application initialization
to configure aiohttp_session properly. to configure aiohttp_session properly.
""" """
import asyncio
try: try:
from aiohttp_session import get_session from aiohttp_session import get_session
HAS_AIOHTTP_SESSION = True HAS_AIOHTTP_SESSION = True
@ -24,17 +22,14 @@ class SessionIdentityPolicy(AbstractIdentityPolicy):
raise ImportError( raise ImportError(
'SessionIdentityPolicy requires `aiohttp_session`') 'SessionIdentityPolicy requires `aiohttp_session`')
@asyncio.coroutine async def identify(self, request):
def identify(self, request): session = await get_session(request)
session = yield from get_session(request)
return session.get(self._session_key) return session.get(self._session_key)
@asyncio.coroutine async def remember(self, request, response, identity, **kwargs):
def remember(self, request, response, identity, **kwargs): session = await get_session(request)
session = yield from get_session(request)
session[self._session_key] = identity session[self._session_key] = identity
@asyncio.coroutine async def forget(self, request, response):
def forget(self, request, response): session = await get_session(request)
session = yield from get_session(request)
session.pop(self._session_key, None) session.pop(self._session_key, None)

View File

@ -1,5 +1,3 @@
import asyncio
import sqlalchemy as sa import sqlalchemy as sa
from aiohttp_security.abc import AbstractAuthorizationPolicy from aiohttp_security.abc import AbstractAuthorizationPolicy
@ -12,29 +10,27 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
def __init__(self, dbengine): def __init__(self, dbengine):
self.dbengine = dbengine self.dbengine = dbengine
@asyncio.coroutine
def authorized_userid(self, identity): def authorized_userid(self, identity):
with (yield from self.dbengine) as conn: async with self.dbengine as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled))
query = db.users.count().where(where) query = db.users.count().where(where)
ret = yield from conn.scalar(query) ret = await conn.scalar(query)
if ret: if ret:
return identity return identity
else: else:
return None return None
@asyncio.coroutine async def permits(self, identity, permission, context=None):
def permits(self, identity, permission, context=None):
if identity is None: if identity is None:
return False return False
with (yield from self.dbengine) as conn: async with self.dbengine as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled))
query = db.users.select().where(where) query = db.users.select().where(where)
ret = yield from conn.execute(query) ret = await conn.execute(query)
user = yield from ret.fetchone() user = await ret.fetchone()
if user is not None: if user is not None:
user_id = user[0] user_id = user[0]
is_superuser = user[3] is_superuser = user[3]
@ -43,8 +39,8 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
where = db.permissions.c.user_id == user_id where = db.permissions.c.user_id == user_id
query = db.permissions.select().where(where) query = db.permissions.select().where(where)
ret = yield from conn.execute(query) ret = await conn.execute(query)
result = yield from ret.fetchall() result = await ret.fetchall()
if ret is not None: if ret is not None:
for record in result: for record in result:
if record.perm_name == permission: if record.perm_name == permission:
@ -53,14 +49,13 @@ class DBAuthorizationPolicy(AbstractAuthorizationPolicy):
return False return False
@asyncio.coroutine async def check_credentials(db_engine, username, password):
def check_credentials(db_engine, username, password): async with db_engine as conn:
with (yield from db_engine) as conn:
where = sa.and_(db.users.c.login == username, where = sa.and_(db.users.c.login == username,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled))
query = db.users.select().where(where) query = db.users.select().where(where)
ret = yield from conn.execute(query) ret = await conn.execute(query)
user = yield from ret.fetchone() user = await ret.fetchone()
if user is not None: if user is not None:
hash = user[2] hash = user[2]
return sha256_crypt.verify(password, hash) return sha256_crypt.verify(password, hash)

View File

@ -1,4 +1,3 @@
import asyncio
import functools import functools
from aiohttp import web from aiohttp import web
@ -10,14 +9,13 @@ from .db_auth import check_credentials
def require(permission): def require(permission):
def wrapper(f): def wrapper(f):
@asyncio.coroutine
@functools.wraps(f) @functools.wraps(f)
def wrapped(self, request): async def wrapped(self, request):
has_perm = yield from permits(request, permission) has_perm = await permits(request, permission)
if not has_perm: if not has_perm:
message = 'User has no permission {}'.format(permission) message = 'User has no permission {}'.format(permission)
raise web.HTTPForbidden(body=message.encode()) raise web.HTTPForbidden(body=message.encode())
return (yield from f(self, request)) return await f(self, request)
return wrapped return wrapped
return wrapper return wrapper
@ -40,9 +38,8 @@ class Web(object):
</body> </body>
""" """
@asyncio.coroutine async def index(self, request):
def index(self, request): username = await authorized_userid(request)
username = yield from authorized_userid(request)
if username: if username:
template = self.index_template.format( template = self.index_template.format(
message='Hello, {username}!'.format(username=username)) message='Hello, {username}!'.format(username=username))
@ -51,37 +48,33 @@ class Web(object):
response = web.Response(body=template.encode()) response = web.Response(body=template.encode())
return response return response
@asyncio.coroutine async def login(self, request):
def login(self, request):
response = web.HTTPFound('/') response = web.HTTPFound('/')
form = yield from request.post() form = await request.post()
login = form.get('login') login = form.get('login')
password = form.get('password') password = form.get('password')
db_engine = request.app.db_engine db_engine = request.app.db_engine
if (yield from check_credentials(db_engine, login, password)): if await check_credentials(db_engine, login, password):
yield from remember(request, response, login) await remember(request, response, login)
return response return response
return web.HTTPUnauthorized( return web.HTTPUnauthorized(
body=b'Invalid username/password combination') body=b'Invalid username/password combination')
@require('public') @require('public')
@asyncio.coroutine async def logout(self, request):
def logout(self, request):
response = web.Response(body=b'You have been logged out') response = web.Response(body=b'You have been logged out')
yield from forget(request, response) await forget(request, response)
return response return response
@require('public') @require('public')
@asyncio.coroutine async def internal_page(self, request):
def internal_page(self, request):
response = web.Response( response = web.Response(
body=b'This page is visible for all registered users') body=b'This page is visible for all registered users')
return response return response
@require('protected') @require('protected')
@asyncio.coroutine async def protected_page(self, request):
def protected_page(self, request):
response = web.Response(body=b'You are on protected page') response = web.Response(body=b'You are on protected page')
return response return response

View File

@ -13,10 +13,9 @@ from demo.db_auth import DBAuthorizationPolicy
from demo.handlers import Web from demo.handlers import Web
@asyncio.coroutine
def init(loop): def init(loop):
redis_pool = yield from create_pool(('localhost', 6379)) redis_pool = await create_pool(('localhost', 6379))
db_engine = yield from create_engine(user='aiohttp_security', db_engine = await create_engine(user='aiohttp_security',
password='aiohttp_security', password='aiohttp_security',
database='aiohttp_security', database='aiohttp_security',
host='127.0.0.1') host='127.0.0.1')
@ -31,21 +30,20 @@ def init(loop):
web_handlers.configure(app) web_handlers.configure(app)
handler = app.make_handler() handler = app.make_handler()
srv = yield from loop.create_server(handler, '127.0.0.1', 8080) srv = await loop.create_server(handler, '127.0.0.1', 8080)
print('Server started at http://127.0.0.1:8080') print('Server started at http://127.0.0.1:8080')
return srv, app, handler return srv, app, handler
@asyncio.coroutine async def finalize(srv, app, handler):
def finalize(srv, app, handler):
sock = srv.sockets[0] sock = srv.sockets[0]
app.loop.remove_reader(sock.fileno()) app.loop.remove_reader(sock.fileno())
sock.close() sock.close()
yield from handler.finish_connections(1.0) await handler.finish_connections(1.0)
srv.close() srv.close()
yield from srv.wait_closed() await srv.wait_closed()
yield from app.finish() await app.finish()
def main(): def main():

View File

@ -1,4 +1,3 @@
import asyncio
import functools import functools
from textwrap import dedent from textwrap import dedent
@ -11,14 +10,13 @@ from .authz import check_credentials
def require(permission): def require(permission):
def wrapper(f): def wrapper(f):
@asyncio.coroutine
@functools.wraps(f) @functools.wraps(f)
def wrapped(request): async def wrapped(request):
has_perm = yield from permits(request, permission) has_perm = await permits(request, permission)
if not has_perm: if not has_perm:
message = 'User has no permission {}'.format(permission) message = 'User has no permission {}'.format(permission)
raise web.HTTPForbidden(body=message.encode()) raise web.HTTPForbidden(body=message.encode())
return (yield from f(request)) return await f(request)
return wrapped return wrapped
return wrapper return wrapper

View File

@ -10,16 +10,14 @@ Simple example::
import asyncio import asyncio
from aiohttp import web from aiohttp import web
@asyncio.coroutine async def root_handler(request):
def root_handler(request):
text = "Alive and kicking!" text = "Alive and kicking!"
return web.Response(body=text.encode('utf-8')) return web.Response(body=text.encode('utf-8'))
# option 2: auth at a higher level? # option 2: auth at a higher level?
# set user_id and allowed in the wsgi handler # set user_id and allowed in the wsgi handler
@protect('view_user') @protect('view_user')
@asyncio.coroutine async def user_handler(request):
def user_handler(request):
name = request.match_info.get('name', "Anonymous") name = request.match_info.get('name', "Anonymous")
text = "Hello, " + name text = "Hello, " + name
return web.Response(body=text.encode('utf-8')) return web.Response(body=text.encode('utf-8'))
@ -27,14 +25,12 @@ Simple example::
# option 3: super low # option 3: super low
# wsgi doesn't do anything # wsgi doesn't do anything
@asyncio.coroutine async def user_update_handler(request):
def user_update_handler(request):
# identity, asked_permission # identity, asked_permission
user_id = yield from identity_policy.identify(request) user_id = await identity_policy.identify(request)
identity = yield from auth_policy.authorized_user_id(user_id) identity = await auth_policy.authorized_user_id(user_id)
allowed = yield from request.auth_policy.permits( allowed = await request.auth_policy.permits(identity,
identity, asked_permission asked_permission)
)
if not allowed: if not allowed:
# how is this pluggable as well? # how is this pluggable as well?
# ? return NotAllowedStream() # ? return NotAllowedStream()
@ -42,8 +38,7 @@ Simple example::
update_user() update_user()
@asyncio.coroutine async def init(loop):
def init(loop):
# set up identity and auth # set up identity and auth
auth_policy = DictionaryAuthorizationPolicy({'me': ('view_user',), auth_policy = DictionaryAuthorizationPolicy({'me': ('view_user',),
'you': ('view_user', 'you': ('view_user',
@ -60,7 +55,7 @@ Simple example::
app.router.add_route('GET', '/{user}/edit', user_update_handler) app.router.add_route('GET', '/{user}/edit', user_update_handler)
# get it started # get it started
srv = yield from loop.create_server(app.make_handler(), srv = await loop.create_server(app.make_handler(),
'127.0.0.1', 8080) '127.0.0.1', 8080)
print("Server started at http://127.0.0.1:8080") print("Server started at http://127.0.0.1:8080")
return srv return srv

View File

@ -67,13 +67,12 @@ In our example we will lookup database by user login and if present return
this identity:: this identity::
@asyncio.coroutine async def authorized_userid(self, identity):
def authorized_userid(self, identity): async with self.dbengine as conn:
with (yield from self.dbengine) as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled))
query = db.users.count().where(where) query = db.users.count().where(where)
ret = yield from conn.scalar(query) ret = await conn.scalar(query)
if ret: if ret:
return identity return identity
else: else:
@ -84,17 +83,16 @@ For permission check we will fetch the user first, check if he is superuser
(all permissions are allowed), otherwise check if permission is explicitly set (all permissions are allowed), otherwise check if permission is explicitly set
for that user:: for that user::
@asyncio.coroutine async def permits(self, identity, permission, context=None):
def permits(self, identity, permission, context=None):
if identity is None: if identity is None:
return False return False
with (yield from self.dbengine) as conn: async with self.dbengine as conn:
where = sa.and_(db.users.c.login == identity, where = sa.and_(db.users.c.login == identity,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled))
query = db.users.select().where(where) query = db.users.select().where(where)
ret = yield from conn.execute(query) ret = await conn.execute(query)
user = yield from ret.fetchone() user = await ret.fetchone()
if user is not None: if user is not None:
user_id = user[0] user_id = user[0]
is_superuser = user[4] is_superuser = user[4]
@ -103,8 +101,8 @@ for that user::
where = db.permissions.c.user_id == user_id where = db.permissions.c.user_id == user_id
query = db.permissions.select().where(where) query = db.permissions.select().where(where)
ret = yield from conn.execute(query) ret = await conn.execute(query)
result = yield from ret.fetchall() result = await ret.fetchall()
if ret is not None: if ret is not None:
for record in result: for record in result:
if record.perm_name == permission: if record.perm_name == permission:
@ -127,13 +125,12 @@ Once we have all the code in place we can install it for our application::
from .db_auth import DBAuthorizationPolicy from .db_auth import DBAuthorizationPolicy
@asyncio.coroutine async def init(loop):
def init(loop): redis_pool = await create_pool(('localhost', 6379))
redis_pool = yield from create_pool(('localhost', 6379)) dbengine = await create_engine(user='aiohttp_security',
dbengine = yield from create_engine(user='aiohttp_security', password='aiohttp_security',
password='aiohttp_security', database='aiohttp_security',
database='aiohttp_security', host='127.0.0.1')
host='127.0.0.1')
app = web.Application(loop=loop) app = web.Application(loop=loop)
setup_session(app, RedisStorage(redis_pool)) setup_session(app, RedisStorage(redis_pool))
setup_security(app, setup_security(app,
@ -148,14 +145,13 @@ help to do that::
def require(permission): def require(permission):
def wrapper(f): def wrapper(f):
@asyncio.coroutine
@functools.wraps(f) @functools.wraps(f)
def wrapped(self, request): async def wrapped(self, request):
has_perm = yield from permits(request, permission) has_perm = await permits(request, permission)
if not has_perm: if not has_perm:
message = 'User has no permission {}'.format(permission) message = 'User has no permission {}'.format(permission)
raise web.HTTPForbidden(body=message.encode()) raise web.HTTPForbidden(body=message.encode())
return (yield from f(self, request)) return await f(self, request)
return wrapped return wrapped
return wrapper return wrapper
@ -164,8 +160,7 @@ For each view you need to protect just apply the decorator on it::
class Web: class Web:
@require('protected') @require('protected')
@asyncio.coroutine async def protected_page(self, request):
def protected_page(self, request):
response = web.Response(body=b'You are on protected page') response = web.Response(body=b'You are on protected page')
return response return response
@ -187,14 +182,13 @@ function may do what you trying to accomplish::
from passlib.hash import sha256_crypt from passlib.hash import sha256_crypt
@asyncio.coroutine async def check_credentials(db_engine, username, password):
def check_credentials(db_engine, username, password): async with db_engine as conn:
with (yield from db_engine) as conn:
where = sa.and_(db.users.c.login == username, where = sa.and_(db.users.c.login == username,
sa.not_(db.users.c.disabled)) sa.not_(db.users.c.disabled))
query = db.users.select().where(where) query = db.users.select().where(where)
ret = yield from conn.execute(query) ret = await conn.execute(query)
user = yield from ret.fetchone() user = await ret.fetchone()
if user is not None: if user is not None:
hash = user[2] hash = user[2]
return sha256_crypt.verify(password, hash) return sha256_crypt.verify(password, hash)

View File

@ -1,5 +1,3 @@
import asyncio
from aiohttp import web from aiohttp import web
from aiohttp_security import (remember, forget, from aiohttp_security import (remember, forget,
AbstractAuthorizationPolicy) AbstractAuthorizationPolicy)
@ -10,47 +8,39 @@ from aiohttp_security.api import IDENTITY_KEY
class Autz(AbstractAuthorizationPolicy): class Autz(AbstractAuthorizationPolicy):
@asyncio.coroutine async def permits(self, identity, permission, context=None):
def permits(self, identity, permission, context=None):
pass pass
@asyncio.coroutine async def authorized_userid(self, identity):
def authorized_userid(self, identity):
pass pass
@asyncio.coroutine async def test_remember(loop, test_client):
def test_remember(loop, test_client):
@asyncio.coroutine async def handler(request):
def handler(request):
response = web.Response() response = web.Response()
yield from remember(request, response, 'Andrew') await remember(request, response, 'Andrew')
return response return response
app = web.Application(loop=loop) app = web.Application(loop=loop)
_setup(app, CookiesIdentityPolicy(), Autz()) _setup(app, CookiesIdentityPolicy(), Autz())
app.router.add_route('GET', '/', handler) app.router.add_route('GET', '/', handler)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
assert 'Andrew' == resp.cookies['AIOHTTP_SECURITY'].value assert 'Andrew' == resp.cookies['AIOHTTP_SECURITY'].value
yield from resp.release()
@asyncio.coroutine async def test_identify(loop, test_client):
def test_identify(loop, test_client):
@asyncio.coroutine async def create(request):
def create(request):
response = web.Response() response = web.Response()
yield from remember(request, response, 'Andrew') await remember(request, response, 'Andrew')
return response return response
@asyncio.coroutine async def check(request):
def check(request):
policy = request.app[IDENTITY_KEY] policy = request.app[IDENTITY_KEY]
user_id = yield from policy.identify(request) user_id = await policy.identify(request)
assert 'Andrew' == user_id assert 'Andrew' == user_id
return web.Response() return web.Response()
@ -58,32 +48,27 @@ def test_identify(loop, test_client):
_setup(app, CookiesIdentityPolicy(), Autz()) _setup(app, CookiesIdentityPolicy(), Autz())
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
app.router.add_route('POST', '/', create) app.router.add_route('POST', '/', create)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/') resp = await client.post('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release() await resp.release()
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
@asyncio.coroutine async def test_forget(loop, test_client):
def test_forget(loop, test_client):
@asyncio.coroutine async def index(request):
def index(request):
return web.Response() return web.Response()
@asyncio.coroutine async def login(request):
def login(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from remember(request, response, 'Andrew') await remember(request, response, 'Andrew')
return response return response
@asyncio.coroutine async def logout(request):
def logout(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from forget(request, response) await forget(request, response)
return response return response
app = web.Application(loop=loop) app = web.Application(loop=loop)
@ -91,18 +76,17 @@ def test_forget(loop, test_client):
app.router.add_route('GET', '/', index) app.router.add_route('GET', '/', index)
app.router.add_route('POST', '/login', login) app.router.add_route('POST', '/login', login)
app.router.add_route('POST', '/logout', logout) app.router.add_route('POST', '/logout', logout)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/login') resp = await client.post('/login')
assert 200 == resp.status assert 200 == resp.status
assert str(resp.url).endswith('/') assert str(resp.url).endswith('/')
cookies = client.session.cookie_jar.filter_cookies( cookies = client.session.cookie_jar.filter_cookies(
client.make_url('/')) client.make_url('/'))
assert 'Andrew' == cookies['AIOHTTP_SECURITY'].value assert 'Andrew' == cookies['AIOHTTP_SECURITY'].value
yield from resp.release()
resp = yield from client.post('/logout') resp = await client.post('/logout')
assert 200 == resp.status assert 200 == resp.status
assert str(resp.url).endswith('/') assert str(resp.url).endswith('/')
cookies = client.session.cookie_jar.filter_cookies( cookies = client.session.cookie_jar.filter_cookies(
client.make_url('/')) client.make_url('/'))
assert 'AIOHTTP_SECURITY' not in cookies assert 'AIOHTTP_SECURITY' not in cookies
yield from resp.release()

View File

@ -1,4 +1,3 @@
import asyncio
import enum import enum
from aiohttp import web from aiohttp import web
@ -11,33 +10,28 @@ from aiohttp_security.cookies_identity import CookiesIdentityPolicy
class Autz(AbstractAuthorizationPolicy): class Autz(AbstractAuthorizationPolicy):
@asyncio.coroutine async def permits(self, identity, permission, context=None):
def permits(self, identity, permission, context=None):
if identity == 'UserID': if identity == 'UserID':
return permission in {'read', 'write'} return permission in {'read', 'write'}
else: else:
return False return False
@asyncio.coroutine async def authorized_userid(self, identity):
def authorized_userid(self, identity):
if identity == 'UserID': if identity == 'UserID':
return 'Andrew' return 'Andrew'
else: else:
return None return None
@asyncio.coroutine async def test_authorized_userid(loop, test_client):
def test_authorized_userid(loop, test_client):
@asyncio.coroutine async def login(request):
def login(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from remember(request, response, 'UserID') await remember(request, response, 'UserID')
return response return response
@asyncio.coroutine async def check(request):
def check(request): userid = await authorized_userid(request)
userid = yield from authorized_userid(request)
assert 'Andrew' == userid assert 'Andrew' == userid
return web.Response(text=userid) return web.Response(text=userid)
@ -45,36 +39,31 @@ def test_authorized_userid(loop, test_client):
_setup(app, CookiesIdentityPolicy(), Autz()) _setup(app, CookiesIdentityPolicy(), Autz())
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
app.router.add_route('POST', '/login', login) app.router.add_route('POST', '/login', login)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/login') resp = await client.post('/login')
assert 200 == resp.status assert 200 == resp.status
txt = yield from resp.text() txt = await resp.text()
assert 'Andrew' == txt assert 'Andrew' == txt
yield from resp.release()
@asyncio.coroutine async def test_authorized_userid_not_authorized(loop, test_client):
def test_authorized_userid_not_authorized(loop, test_client):
@asyncio.coroutine async def check(request):
def check(request): userid = await authorized_userid(request)
userid = yield from authorized_userid(request)
assert userid is None assert userid is None
return web.Response() return web.Response()
app = web.Application(loop=loop) app = web.Application(loop=loop)
_setup(app, CookiesIdentityPolicy(), Autz()) _setup(app, CookiesIdentityPolicy(), Autz())
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
@asyncio.coroutine async def test_permits_enum_permission(loop, test_client):
def test_permits_enum_permission(loop, test_client):
class Permission(enum.Enum): class Permission(enum.Enum):
READ = '101' READ = '101'
WRITE = '102' WRITE = '102'
@ -82,33 +71,29 @@ def test_permits_enum_permission(loop, test_client):
class Autz(AbstractAuthorizationPolicy): class Autz(AbstractAuthorizationPolicy):
@asyncio.coroutine async def permits(self, identity, permission, context=None):
def permits(self, identity, permission, context=None):
if identity == 'UserID': if identity == 'UserID':
return permission in {Permission.READ, Permission.WRITE} return permission in {Permission.READ, Permission.WRITE}
else: else:
return False return False
@asyncio.coroutine async def authorized_userid(self, identity):
def authorized_userid(self, identity):
if identity == 'UserID': if identity == 'UserID':
return 'Andrew' return 'Andrew'
else: else:
return None return None
@asyncio.coroutine async def login(request):
def login(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from remember(request, response, 'UserID') await remember(request, response, 'UserID')
return response return response
@asyncio.coroutine async def check(request):
def check(request): ret = await permits(request, Permission.READ)
ret = yield from permits(request, Permission.READ)
assert ret assert ret
ret = yield from permits(request, Permission.WRITE) ret = await permits(request, Permission.WRITE)
assert ret assert ret
ret = yield from permits(request, Permission.UNKNOWN) ret = await permits(request, Permission.UNKNOWN)
assert not ret assert not ret
return web.Response() return web.Response()
@ -116,54 +101,46 @@ def test_permits_enum_permission(loop, test_client):
_setup(app, CookiesIdentityPolicy(), Autz()) _setup(app, CookiesIdentityPolicy(), Autz())
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
app.router.add_route('POST', '/login', login) app.router.add_route('POST', '/login', login)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/login') resp = await client.post('/login')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
@asyncio.coroutine async def test_permits_unauthorized(loop, test_client):
def test_permits_unauthorized(loop, test_client):
@asyncio.coroutine async def check(request):
def check(request): ret = await permits(request, 'read')
ret = yield from permits(request, 'read')
assert not ret assert not ret
ret = yield from permits(request, 'write') ret = await permits(request, 'write')
assert not ret assert not ret
ret = yield from permits(request, 'unknown') ret = await permits(request, 'unknown')
assert not ret assert not ret
return web.Response() return web.Response()
app = web.Application(loop=loop) app = web.Application(loop=loop)
_setup(app, CookiesIdentityPolicy(), Autz()) _setup(app, CookiesIdentityPolicy(), Autz())
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
@asyncio.coroutine async def test_is_anonymous(loop, test_client):
def test_is_anonymous(loop, test_client):
@asyncio.coroutine async def index(request):
def index(request): is_anon = await is_anonymous(request)
is_anon = yield from is_anonymous(request)
if is_anon: if is_anon:
return web.HTTPUnauthorized() return web.HTTPUnauthorized()
return web.HTTPOk() return web.HTTPOk()
@asyncio.coroutine async def login(request):
def login(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from remember(request, response, 'UserID') await remember(request, response, 'UserID')
return response return response
@asyncio.coroutine async def logout(request):
def logout(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from forget(request, response) await forget(request, response)
return response return response
app = web.Application(loop=loop) app = web.Application(loop=loop)
@ -171,36 +148,32 @@ def test_is_anonymous(loop, test_client):
app.router.add_route('GET', '/', index) app.router.add_route('GET', '/', index)
app.router.add_route('POST', '/login', login) app.router.add_route('POST', '/login', login)
app.router.add_route('POST', '/logout', logout) app.router.add_route('POST', '/logout', logout)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
yield from client.post('/login') await client.post('/login')
resp = yield from client.get('/') resp = await client.get('/')
assert web.HTTPOk.status_code == resp.status assert web.HTTPOk.status_code == resp.status
yield from client.post('/logout') await client.post('/logout')
resp = yield from client.get('/') resp = await client.get('/')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
@asyncio.coroutine async def test_login_required(loop, test_client):
def test_login_required(loop, test_client):
@login_required @login_required
@asyncio.coroutine async def index(request):
def index(request):
return web.HTTPOk() return web.HTTPOk()
@asyncio.coroutine async def login(request):
def login(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from remember(request, response, 'UserID') await remember(request, response, 'UserID')
return response return response
@asyncio.coroutine async def logout(request):
def logout(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from forget(request, response) await forget(request, response)
return response return response
app = web.Application(loop=loop) app = web.Application(loop=loop)
@ -208,47 +181,41 @@ def test_login_required(loop, test_client):
app.router.add_route('GET', '/', index) app.router.add_route('GET', '/', index)
app.router.add_route('POST', '/login', login) app.router.add_route('POST', '/login', login)
app.router.add_route('POST', '/logout', logout) app.router.add_route('POST', '/logout', logout)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
yield from client.post('/login') await client.post('/login')
resp = yield from client.get('/') resp = await client.get('/')
assert web.HTTPOk.status_code == resp.status assert web.HTTPOk.status_code == resp.status
yield from client.post('/logout') await client.post('/logout')
resp = yield from client.get('/') resp = await client.get('/')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
@asyncio.coroutine async def test_has_permission(loop, test_client):
def test_has_permission(loop, test_client):
@has_permission('read') @has_permission('read')
@asyncio.coroutine async def index_read(request):
def index_read(request):
return web.HTTPOk() return web.HTTPOk()
@has_permission('write') @has_permission('write')
@asyncio.coroutine async def index_write(request):
def index_write(request):
return web.HTTPOk() return web.HTTPOk()
@has_permission('forbid') @has_permission('forbid')
@asyncio.coroutine async def index_forbid(request):
def index_forbid(request):
return web.HTTPOk() return web.HTTPOk()
@asyncio.coroutine async def login(request):
def login(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from remember(request, response, 'UserID') await remember(request, response, 'UserID')
return response return response
@asyncio.coroutine async def logout(request):
def logout(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from forget(request, response) await forget(request, response)
return response return response
app = web.Application(loop=loop) app = web.Application(loop=loop)
@ -258,27 +225,27 @@ def test_has_permission(loop, test_client):
app.router.add_route('GET', '/permission/forbid', index_forbid) app.router.add_route('GET', '/permission/forbid', index_forbid)
app.router.add_route('POST', '/login', login) app.router.add_route('POST', '/login', login)
app.router.add_route('POST', '/logout', logout) app.router.add_route('POST', '/logout', logout)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/permission/read') resp = await client.get('/permission/read')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
resp = yield from client.get('/permission/write') resp = await client.get('/permission/write')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
resp = yield from client.get('/permission/forbid') resp = await client.get('/permission/forbid')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
yield from client.post('/login') await client.post('/login')
resp = yield from client.get('/permission/read') resp = await client.get('/permission/read')
assert web.HTTPOk.status_code == resp.status assert web.HTTPOk.status_code == resp.status
resp = yield from client.get('/permission/write') resp = await client.get('/permission/write')
assert web.HTTPOk.status_code == resp.status assert web.HTTPOk.status_code == resp.status
resp = yield from client.get('/permission/forbid') resp = await client.get('/permission/forbid')
assert web.HTTPForbidden.status_code == resp.status assert web.HTTPForbidden.status_code == resp.status
yield from client.post('/logout') await client.post('/logout')
resp = yield from client.get('/permission/read') resp = await client.get('/permission/read')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
resp = yield from client.get('/permission/write') resp = await client.get('/permission/write')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status
resp = yield from client.get('/permission/forbid') resp = await client.get('/permission/forbid')
assert web.HTTPUnauthorized.status_code == resp.status assert web.HTTPUnauthorized.status_code == resp.status

View File

@ -1,42 +1,34 @@
import asyncio
from aiohttp import web from aiohttp import web
from aiohttp_security import authorized_userid, permits from aiohttp_security import authorized_userid, permits
@asyncio.coroutine async def test_authorized_userid(loop, test_client):
def test_authorized_userid(loop, test_client):
@asyncio.coroutine async def check(request):
def check(request): userid = await authorized_userid(request)
userid = yield from authorized_userid(request)
assert userid is None assert userid is None
return web.Response() return web.Response()
app = web.Application(loop=loop) app = web.Application(loop=loop)
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
@asyncio.coroutine async def test_permits(loop, test_client):
def test_permits(loop, test_client):
@asyncio.coroutine async def check(request):
def check(request): ret = await permits(request, 'read')
ret = yield from permits(request, 'read')
assert ret assert ret
ret = yield from permits(request, 'write') ret = await permits(request, 'write')
assert ret assert ret
ret = yield from permits(request, 'unknown') ret = await permits(request, 'unknown')
assert ret assert ret
return web.Response() return web.Response()
app = web.Application(loop=loop) app = web.Application(loop=loop)
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()

View File

@ -1,42 +1,34 @@
import asyncio
from aiohttp import web from aiohttp import web
from aiohttp_security import remember, forget from aiohttp_security import remember, forget
@asyncio.coroutine async def test_remember(loop, test_client):
def test_remember(loop, test_client):
@asyncio.coroutine async def do_remember(request):
def do_remember(request):
response = web.Response() response = web.Response()
yield from remember(request, response, 'Andrew') await remember(request, response, 'Andrew')
app = web.Application(loop=loop) app = web.Application(loop=loop)
app.router.add_route('POST', '/', do_remember) app.router.add_route('POST', '/', do_remember)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/') resp = await client.post('/')
assert 500 == resp.status assert 500 == resp.status
assert (('Security subsystem is not initialized, ' assert (('Security subsystem is not initialized, '
'call aiohttp_security.setup(...) first') == 'call aiohttp_security.setup(...) first') ==
resp.reason) resp.reason)
yield from resp.release()
@asyncio.coroutine async def test_forget(loop, test_client):
def test_forget(loop, test_client):
@asyncio.coroutine async def do_forget(request):
def do_forget(request):
response = web.Response() response = web.Response()
yield from forget(request, response) await forget(request, response)
app = web.Application(loop=loop) app = web.Application(loop=loop)
app.router.add_route('POST', '/', do_forget) app.router.add_route('POST', '/', do_forget)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/') resp = await client.post('/')
assert 500 == resp.status assert 500 == resp.status
assert (('Security subsystem is not initialized, ' assert (('Security subsystem is not initialized, '
'call aiohttp_security.setup(...) first') == 'call aiohttp_security.setup(...) first') ==
resp.reason) resp.reason)
yield from resp.release()

View File

@ -1,4 +1,3 @@
import asyncio
import pytest import pytest
from aiohttp import web from aiohttp import web
@ -13,12 +12,10 @@ from aiohttp_session import setup as setup_session
class Autz(AbstractAuthorizationPolicy): class Autz(AbstractAuthorizationPolicy):
@asyncio.coroutine async def permits(self, identity, permission, context=None):
def permits(self, identity, permission, context=None):
pass pass
@asyncio.coroutine async def authorized_userid(self, identity):
def authorized_userid(self, identity):
pass pass
@ -30,81 +27,67 @@ def make_app(loop):
return app return app
@asyncio.coroutine async def test_remember(make_app, test_client):
def test_remember(make_app, test_client):
@asyncio.coroutine async def handler(request):
def handler(request):
response = web.Response() response = web.Response()
yield from remember(request, response, 'Andrew') await remember(request, response, 'Andrew')
return response return response
@asyncio.coroutine async def check(request):
def check(request): session = await get_session(request)
session = yield from get_session(request)
assert session['AIOHTTP_SECURITY'] == 'Andrew' assert session['AIOHTTP_SECURITY'] == 'Andrew'
return web.HTTPOk() return web.HTTPOk()
app = make_app() app = make_app()
app.router.add_route('GET', '/', handler) app.router.add_route('GET', '/', handler)
app.router.add_route('GET', '/check', check) app.router.add_route('GET', '/check', check)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
resp = yield from client.get('/check') resp = await client.get('/check')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
@asyncio.coroutine async def test_identify(make_app, test_client):
def test_identify(make_app, test_client):
@asyncio.coroutine async def create(request):
def create(request):
response = web.Response() response = web.Response()
yield from remember(request, response, 'Andrew') await remember(request, response, 'Andrew')
return response return response
@asyncio.coroutine async def check(request):
def check(request):
policy = request.app[IDENTITY_KEY] policy = request.app[IDENTITY_KEY]
user_id = yield from policy.identify(request) user_id = await policy.identify(request)
assert 'Andrew' == user_id assert 'Andrew' == user_id
return web.Response() return web.Response()
app = make_app() app = make_app()
app.router.add_route('GET', '/', check) app.router.add_route('GET', '/', check)
app.router.add_route('POST', '/', create) app.router.add_route('POST', '/', create)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/') resp = await client.post('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
resp = yield from client.get('/') resp = await client.get('/')
assert 200 == resp.status assert 200 == resp.status
yield from resp.release()
@asyncio.coroutine async def test_forget(make_app, test_client):
def test_forget(make_app, test_client):
@asyncio.coroutine async def index(request):
def index(request): session = await get_session(request)
session = yield from get_session(request)
return web.HTTPOk(text=session.get('AIOHTTP_SECURITY', '')) return web.HTTPOk(text=session.get('AIOHTTP_SECURITY', ''))
@asyncio.coroutine async def login(request):
def login(request):
response = web.HTTPFound(location='/') response = web.HTTPFound(location='/')
yield from remember(request, response, 'Andrew') await remember(request, response, 'Andrew')
return response return response
@asyncio.coroutine async def logout(request):
def logout(request):
response = web.HTTPFound('/') response = web.HTTPFound('/')
yield from forget(request, response) await forget(request, response)
return response return response
app = make_app() app = make_app()
@ -112,18 +95,16 @@ def test_forget(make_app, test_client):
app.router.add_route('POST', '/login', login) app.router.add_route('POST', '/login', login)
app.router.add_route('POST', '/logout', logout) app.router.add_route('POST', '/logout', logout)
client = yield from test_client(app) client = await test_client(app)
resp = yield from client.post('/login') resp = await client.post('/login')
assert 200 == resp.status assert 200 == resp.status
assert str(resp.url).endswith('/') assert str(resp.url).endswith('/')
txt = yield from resp.text() txt = await resp.text()
assert 'Andrew' == txt assert 'Andrew' == txt
yield from resp.release()
resp = yield from client.post('/logout') resp = await client.post('/logout')
assert 200 == resp.status assert 200 == resp.status
assert str(resp.url).endswith('/') assert str(resp.url).endswith('/')
txt = yield from resp.text() txt = await resp.text()
assert '' == txt assert '' == txt
yield from resp.release()