232 lines
8.0 KiB
Python
232 lines
8.0 KiB
Python
import binascii
|
|
import hashlib
|
|
from microdot import Request, Response
|
|
from microdot.microdot import MUTED_SOCKET_ERRORS, print_exception
|
|
from microdot.helpers import wraps
|
|
|
|
|
|
class WebSocketError(Exception):
    """Error raised when something goes wrong on a WebSocket connection."""
|
|
|
|
|
|
class WebSocket:
    """A WebSocket connection object.

    An instance of this class is sent to handler functions to manage the
    WebSocket connection.

    The object reads incoming frames from ``request.sock[0]`` and writes
    outgoing data to ``request.sock[1]``.
    """
    # Frame opcodes, as they appear in the low four bits of the first byte of
    # a frame header.
    CONT = 0
    TEXT = 1
    BINARY = 2
    CLOSE = 8
    PING = 9
    PONG = 10

    #: Specify the maximum message size that can be received when calling the
    #: ``receive()`` method. Messages with payloads that are larger than this
    #: size will be rejected and the connection closed. Set to 0 to disable
    #: the size check (be aware of potential security issues if you do this),
    #: or to -1 to use the value set in
    #: ``Request.max_body_length``. The default is -1.
    #:
    #: Example::
    #:
    #:    WebSocket.max_message_length = 4 * 1024  # up to 4KB messages
    max_message_length = -1

    def __init__(self, request):
        """Create a WebSocket bound to the given request.

        :param request: the request object whose ``sock`` and ``headers``
                        attributes are used for the connection.
        """
        self.request = request
        self.closed = False

    async def handshake(self):
        """Write the ``101 Switching Protocols`` response to the client.

        The ``Sec-WebSocket-Accept`` value is computed first, so that an
        invalid upgrade request is rejected before anything is written.
        """
        response = self._handshake_response()
        await self.request.sock[1].awrite(
            b'HTTP/1.1 101 Switching Protocols\r\n')
        await self.request.sock[1].awrite(b'Upgrade: websocket\r\n')
        await self.request.sock[1].awrite(b'Connection: Upgrade\r\n')
        await self.request.sock[1].awrite(
            b'Sec-WebSocket-Accept: ' + response + b'\r\n\r\n')

    async def receive(self):
        """Receive a message from the client.

        Frames are read until one produces a non-empty message. Pings are
        answered with a pong, and frames that yield no data (such as pongs)
        are skipped.

        :return: a string for text messages, or bytes for binary messages.
        :raises WebSocketError: if the connection is closed, the message is
                                too large or the frame is not supported.
        """
        while True:
            opcode, payload = await self._read_frame()
            send_opcode, data = self._process_websocket_frame(opcode, payload)
            if send_opcode:  # pragma: no cover
                await self.send(data, send_opcode)
            elif data:  # pragma: no branch
                return data

    async def send(self, data, opcode=None):
        """Send a message to the client.

        :param data: the data to send, given as a string or bytes.
        :param opcode: a custom frame opcode to use. If not given, the opcode
                       is ``TEXT`` or ``BINARY`` depending on the type of the
                       data.
        """
        frame = self._encode_websocket_frame(
            opcode or (self.TEXT if isinstance(data, str) else self.BINARY),
            data)
        await self.request.sock[1].awrite(frame)

    async def close(self):
        """Close the websocket connection.

        An empty ``CLOSE`` frame is sent the first time this method is
        called. Later calls do nothing.
        """
        if not self.closed:  # pragma: no cover
            self.closed = True
            await self.send(b'', self.CLOSE)

    def _handshake_response(self):
        """Validate the upgrade headers and compute the accept token.

        The request must carry a ``Connection`` header that mentions
        ``upgrade``, an ``Upgrade: websocket`` header and a
        ``Sec-WebSocket-Key`` header. When any of these is missing or
        invalid, the result of ``self.request.app.abort(400)`` is returned.

        :return: the ``Sec-WebSocket-Accept`` value, as bytes.
        """
        connection = False
        upgrade = False
        websocket_key = None
        for header, value in self.request.headers.items():
            h = header.lower()
            if h == 'connection':
                connection = True
                if 'upgrade' not in value.lower():
                    return self.request.app.abort(400)
            elif h == 'upgrade':
                upgrade = True
                if value.lower() != 'websocket':
                    return self.request.app.abort(400)
            elif h == 'sec-websocket-key':
                websocket_key = value
        if not connection or not upgrade or not websocket_key:
            return self.request.app.abort(400)
        # The accept token is the base64-encoded SHA-1 digest of the client
        # key followed by the fixed GUID of the WebSocket protocol.
        d = hashlib.sha1(websocket_key.encode())
        d.update(b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11')
        # b2a_base64() appends a newline, which is dropped here
        return binascii.b2a_base64(d.digest())[:-1]

    @classmethod
    def _parse_frame_header(cls, header):
        """Decode the first two bytes of a frame.

        :param header: the two header bytes.
        :return: a ``(fin, opcode, has_mask, length)`` tuple. ``fin`` and
                 ``has_mask`` are the raw flag bits (zero when unset). The
                 length is returned as ``-2`` or ``-8`` when the actual
                 length follows the header as a 2 or 8 byte integer.
        :raises WebSocketError: if the frame is fragmented.
        """
        fin = header[0] & 0x80
        opcode = header[0] & 0x0f
        if fin == 0 or opcode == cls.CONT:  # pragma: no cover
            raise WebSocketError('Continuation frames not supported')
        has_mask = header[1] & 0x80
        length = header[1] & 0x7f
        if length == 126:
            length = -2
        elif length == 127:
            length = -8
        return fin, opcode, has_mask, length

    def _process_websocket_frame(self, opcode, payload):
        """Decide what to do with a received frame.

        :param opcode: the opcode of the frame.
        :param payload: the unmasked payload of the frame, as bytes.
        :return: a ``(send_opcode, data)`` tuple. When ``send_opcode`` is
                 set, ``data`` must be sent back to the client with that
                 opcode. Otherwise ``data`` is the received message, or
                 ``None`` if the frame carries no message.
        :raises WebSocketError: if the frame is a ``CLOSE`` frame.
        """
        if opcode == self.TEXT:
            payload = payload.decode()
        elif opcode == self.BINARY:
            pass
        elif opcode == self.CLOSE:
            raise WebSocketError('Websocket connection closed')
        elif opcode == self.PING:
            # a ping is answered with a pong that echoes its payload
            return self.PONG, payload
        elif opcode == self.PONG:  # pragma: no branch
            return None, None
        return None, payload

    @classmethod
    def _encode_websocket_frame(cls, opcode, payload):
        """Build an unmasked, unfragmented frame.

        :param opcode: the opcode of the frame.
        :param payload: the payload, given as a string or bytes. Strings are
                        encoded to bytes regardless of the opcode.
        :return: the encoded frame, as a ``bytearray``.
        """
        frame = bytearray()
        frame.append(0x80 | opcode)  # the FIN bit is always set
        if isinstance(payload, str):
            payload = payload.encode()
        if len(payload) < 126:
            frame.append(len(payload))
        elif len(payload) < (1 << 16):
            frame.append(126)
            frame.extend(len(payload).to_bytes(2, 'big'))
        else:
            frame.append(127)
            frame.extend(len(payload).to_bytes(8, 'big'))
        frame.extend(payload)
        return frame

    async def _read_exactly(self, size):
        """Read exactly ``size`` bytes from the client.

        A single ``read()`` call on a stream may return fewer bytes than
        requested, so reads are repeated until all the bytes are received.

        :raises WebSocketError: if the stream ends before ``size`` bytes are
                                received.
        """
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = await self.request.sock[0].read(remaining)
            if not chunk:  # pragma: no cover
                raise WebSocketError('Websocket connection closed')
            chunks.append(chunk)
            remaining -= len(chunk)
        return b''.join(chunks)

    async def _read_frame(self):
        """Read one frame from the client.

        :return: an ``(opcode, payload)`` tuple, with the payload given as
                 unmasked bytes.
        :raises WebSocketError: if the connection is closed, the frame is
                                fragmented or its payload is larger than the
                                allowed maximum.
        """
        header = await self._read_exactly(2)
        fin, opcode, has_mask, length = self._parse_frame_header(header)
        if length == -2:
            length = int.from_bytes(await self._read_exactly(2), 'big')
        elif length == -8:
            length = int.from_bytes(await self._read_exactly(8), 'big')
        max_allowed_length = Request.max_body_length \
            if self.max_message_length == -1 else self.max_message_length
        # a limit of 0 disables the size check
        if max_allowed_length and length > max_allowed_length:
            raise WebSocketError('Message too large')
        if has_mask:  # pragma: no cover
            mask = await self._read_exactly(4)
        payload = await self._read_exactly(length)
        if has_mask:  # pragma: no cover
            payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
        return opcode, payload
|
|
|
|
|
|
async def websocket_upgrade(request):
    """Perform the WebSocket handshake for a request.

    Call this function from inside a route function when the upgrade has to
    happen at a specific point, for example once the user's credentials have
    been checked. The WebSocket object for the connection is returned::

        @app.route('/echo')
        async def echo(request):
            if not authenticate_user(request):
                abort(401)
            ws = await websocket_upgrade(request)
            while True:
                message = await ws.receive()
                await ws.send(message)

    :param request: the request to upgrade.
    :return: the ``WebSocket`` instance bound to the request.
    """
    websocket = WebSocket(request)
    await websocket.handshake()

    async def after_request(request, response):
        # Whatever the route returns, report the response as already handled.
        return Response.already_handled

    request.after_request(after_request)
    return websocket
|
|
|
|
|
|
def websocket_wrapper(f, upgrade_function):
    """Wrap a WebSocket handler so that it can be used as a route function.

    :param f: the handler, invoked as ``f(request, ws, *args, **kwargs)``.
    :param upgrade_function: a coroutine function that receives the request
                             and returns the WebSocket object passed to the
                             handler.
    :return: the wrapped route function.

    The wrapper silences ``WebSocketError`` and the socket errors listed in
    ``MUTED_SOCKET_ERRORS``, prints any other exception derived from
    ``Exception``, and always attempts to close the WebSocket before
    finishing.
    """
    async def close_quietly(ws):
        # Closing is best effort, the connection may already be unusable.
        try:
            await ws.close()
        except Exception:
            pass

    @wraps(f)
    async def wrapper(request, *args, **kwargs):
        ws = await upgrade_function(request)
        try:
            await f(request, ws, *args, **kwargs)
        except OSError as err:
            if err.errno not in MUTED_SOCKET_ERRORS:  # pragma: no cover
                raise
        except WebSocketError:
            pass
        except Exception as err:
            print_exception(err)
        finally:  # pragma: no cover
            await close_quietly(ws)
        return Response.already_handled

    return wrapper
|
|
|
|
|
|
def with_websocket(f):
    """Decorator that turns a route into a WebSocket endpoint.

    The decorated route function is called with the WebSocket object as its
    second argument, right after the request, and uses it to exchange
    messages with the client::

        @app.route('/echo')
        @with_websocket
        async def echo(request, ws):
            while True:
                message = await ws.receive()
                await ws.send(message)
    """
    endpoint = websocket_wrapper(f, websocket_upgrade)
    return endpoint
|