import asyncio
import socket
import unittest

import aiohttp
from aiohttp import web
from aiohttp_security import (remember, setup, forget,
                              AbstractAuthorizationPolicy)
from aiohttp_security.cookies_identity import CookiesIdentityPolicy
from aiohttp_security.api import IDENTITY_KEY


class Autz(AbstractAuthorizationPolicy):

    @asyncio.coroutine
    def permits(self, identity, permission, context=None):
        pass

    @asyncio.coroutine
    def authorized_userid(self, identity):
        pass


class TestCookiesIdentity(unittest.TestCase):

    def setUp(self):
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(None)
        self.client = aiohttp.ClientSession(loop=self.loop)

    def tearDown(self):
        self.client.close()
        self.loop.run_until_complete(self.handler.finish_connections())
        self.srv.close()
        self.loop.run_until_complete(self.srv.wait_closed())
        self.loop.close()

    def find_unused_port(self):
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(('127.0.0.1', 0))
        port = s.getsockname()[1]
        s.close()
        return port

    @asyncio.coroutine
    def create_server(self):
        app = web.Application(loop=self.loop)
        setup(app, CookiesIdentityPolicy(), Autz())

        port = self.find_unused_port()
        self.handler = app.make_handler(
            debug=False, keep_alive_on=False)
        srv = yield from self.loop.create_server(
            self.handler, '127.0.0.1', port)
        url = "http://127.0.0.1:{}".format(port)
        self.srv = srv
        return app, srv, url

    def test_remember(self):

        @asyncio.coroutine
        def handler(request):
            response = web.Response()
            yield from remember(request, response, 'Andrew')
            return response

        @asyncio.coroutine
        def go():
            app, srv, url = yield from self.create_server()
            app.router.add_route('GET', '/', handler)
            resp = yield from self.client.get(url+'/')
            self.assertEqual(200, resp.status)
            self.assertEqual('Andrew',
                             self.client.cookies['AIOHTTP_SECURITY'].value)
            yield from resp.release()

        self.loop.run_until_complete(go())

    def test_identify(self):

        @asyncio.coroutine
        def create(request):
            response = web.Response()
            yield from remember(request, response, 'Andrew')
            return response

        @asyncio.coroutine
        def check(request):
            policy = request.app[IDENTITY_KEY]
            user_id = yield from policy.identify(request)
            self.assertEqual('Andrew', user_id)
            return web.Response()

        @asyncio.coroutine
        def go():
            app, srv, url = yield from self.create_server()
            app.router.add_route('GET', '/', check)
            app.router.add_route('POST', '/', create)
            resp = yield from self.client.post(url+'/')
            self.assertEqual(200, resp.status)
            yield from resp.release()
            resp = yield from self.client.get(url+'/')
            self.assertEqual(200, resp.status)
            yield from resp.release()

        self.loop.run_until_complete(go())

    def test_forget(self):

        @asyncio.coroutine
        def index(request):
            return web.Response()

        @asyncio.coroutine
        def login(request):
            response = web.HTTPFound(location='/')
            yield from remember(request, response, 'Andrew')
            return response

        @asyncio.coroutine
        def logout(request):
            response = web.HTTPFound(location='/')
            yield from forget(request, response)
            return response

        @asyncio.coroutine
        def go():
            app, srv, url = yield from self.create_server()
            app.router.add_route('GET', '/', index)
            app.router.add_route('POST', '/login', login)
            app.router.add_route('POST', '/logout', logout)
            resp = yield from self.client.post(url+'/login')
            self.assertEqual(200, resp.status)
            self.assertEqual(url+'/', resp.url)
            self.assertEqual('Andrew',
                             self.client.cookies['AIOHTTP_SECURITY'].value)
            yield from resp.release()
            resp = yield from self.client.post(url+'/logout')
            self.assertEqual(200, resp.status)
            self.assertEqual(url+'/', resp.url)
            self.assertEqual('', self.client.cookies['AIOHTTP_SECURITY'].value)
            yield from resp.release()

        self.loop.run_until_complete(go())