439 lines
14 KiB
Python
439 lines
14 KiB
Python
"""WebSocket protocol versions 13 and 8."""
|
||
|
||
import base64
|
||
import binascii
|
||
import collections
|
||
import hashlib
|
||
import json
|
||
import os
|
||
import random
|
||
import sys
|
||
from enum import IntEnum
|
||
from struct import Struct
|
||
|
||
from aiohttp import errors, hdrs
|
||
from aiohttp.log import ws_logger
|
||
|
||
# Public names exported by a star-import of this module.
__all__ = (
    'WebSocketParser',
    'WebSocketWriter',
    'do_handshake',
    'WSMessage',
    'WebSocketError',
    'WSMsgType',
    'WSCloseCode',
)
|
||
|
||
|
||
class WSCloseCode(IntEnum):
    """Status codes carried in WebSocket close frames."""

    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    INVALID_TEXT = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013


# Plain-int copy of the codes above.  The parser rejects a received close
# code that is below 3000 and not a member of this set.
ALLOWED_CLOSE_CODES = set(map(int, WSCloseCode))
|
||
|
||
|
||
class WSMsgType(IntEnum):
    """Message types used by the parser and writer.

    Members up to ``0xa`` are compared against the 4-bit opcode read
    from a frame header.  ``CLOSED`` and ``ERROR`` are larger than any
    4-bit opcode, so they can never match one read from the wire.
    """

    # Frame opcodes.
    CONTINUATION = 0x0
    TEXT = 0x1
    BINARY = 0x2
    PING = 0x9
    PONG = 0xa
    CLOSE = 0x8

    # Values outside the opcode range.
    CLOSED = 0x101
    ERROR = 0x102

    # Lower-case aliases for the members above.
    text = TEXT
    binary = BINARY
    ping = PING
    pong = PONG
    close = CLOSE
    closed = CLOSED
    error = ERROR
|
||
|
||
|
||
# Constant appended to the client's Sec-WebSocket-Key before hashing it
# into the Sec-WebSocket-Accept response header.
WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'


# Pre-compiled network-byte-order (un)packers.
# Extended payload lengths: 16-bit and 64-bit unsigned.
UNPACK_LEN2 = Struct('!H').unpack_from
UNPACK_LEN3 = Struct('!Q').unpack_from
# Close code: 16-bit unsigned.
UNPACK_CLOSE_CODE = Struct('!H').unpack
# Frame headers: two leading bytes, optionally followed by the 16-bit or
# 64-bit extended length.
PACK_LEN1 = Struct('!BB').pack
PACK_LEN2 = Struct('!BBH').pack
PACK_LEN3 = Struct('!BBQ').pack
PACK_CLOSE_CODE = Struct('!H').pack
# Unmasked payloads longer than this are written separately from their
# header instead of being concatenated with it.
MSG_SIZE = 1 << 14
|
||
|
||
|
||
_WSMessageBase = collections.namedtuple(
    '_WSMessageBase', ('type', 'data', 'extra'))


class WSMessage(_WSMessageBase):
    """A ``(type, data, extra)`` tuple describing one WebSocket message."""

    def json(self, *, loads=json.loads):
        """Return ``self.data`` decoded by *loads*.

        *loads* defaults to :func:`json.loads` and must be passed by
        keyword.

        .. versionadded:: 0.22
        """
        raw = self.data
        return loads(raw)

    @property
    def tp(self):
        """Alias for the ``type`` field."""
        return self.type
|
||
|
||
|
||
# Shared message of type CLOSED that carries no data and no extra.
CLOSED_MESSAGE = WSMessage(type=WSMsgType.CLOSED, data=None, extra=None)
|
||
|
||
|
||
class WebSocketError(Exception):
    """WebSocket protocol parser error.

    ``code`` is stored on the instance; ``message`` becomes the
    exception text.
    """

    def __init__(self, code, message):
        super().__init__(message)
        self.code = code
|
||
|
||
|
||
def _parse_close_frame(fin, opcode, payload):
    """Build the CLOSE ``WSMessage`` for a close frame's payload.

    A payload of two or more bytes is a 16-bit close code followed by a
    UTF-8 reason; an empty payload yields code ``0`` and an empty reason.

    Raises ``WebSocketError`` when the code is below 3000 and not in
    ``ALLOWED_CLOSE_CODES``, when the reason is not valid UTF-8, or when
    the payload is exactly one byte long.  *fin* and *opcode* are only
    used in the error text.
    """
    if len(payload) >= 2:
        close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
        if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR,
                'Invalid close code: {}'.format(close_code))
        try:
            close_message = payload[2:].decode('utf-8')
        except UnicodeDecodeError as exc:
            raise WebSocketError(
                WSCloseCode.INVALID_TEXT,
                'Invalid UTF-8 text message') from exc
        return WSMessage(WSMsgType.CLOSE, close_code, close_message)

    if payload:
        # A single byte cannot hold the 16-bit close code.
        raise WebSocketError(
            WSCloseCode.PROTOCOL_ERROR,
            'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload))

    return WSMessage(WSMsgType.CLOSE, 0, '')


def WebSocketParser(out, buf):
    """Parse frames from *buf* forever and feed messages to *out*.

    This is a generator: frames are obtained with
    ``yield from parse_frame(buf)`` and every complete message is passed
    to ``out.feed_data(message, size)``.  Fragmented TEXT/BINARY messages
    are reassembled before being delivered; PING, PONG and CLOSE frames
    that arrive between fragments are delivered immediately.

    Raises ``WebSocketError`` on an unexpected opcode, an invalid close
    frame, a fragment that is not a CONTINUATION frame, or a TEXT message
    that is not valid UTF-8.
    """
    while True:
        fin, opcode, payload = yield from parse_frame(buf)

        if opcode == WSMsgType.CLOSE:
            out.feed_data(_parse_close_frame(fin, opcode, payload), 0)

        elif opcode == WSMsgType.PING:
            out.feed_data(WSMessage(WSMsgType.PING, payload, ''), len(payload))

        elif opcode == WSMsgType.PONG:
            out.feed_data(WSMessage(WSMsgType.PONG, payload, ''), len(payload))

        elif opcode not in (WSMsgType.TEXT, WSMsgType.BINARY):
            raise WebSocketError(
                WSCloseCode.PROTOCOL_ERROR,
                "Unexpected opcode={!r}".format(opcode))
        else:
            # load text/binary
            data = [payload]

            while not fin:
                fin, _opcode, payload = yield from parse_frame(buf, True)

                # Control frames may be injected in the middle of a
                # fragmented message (Case 5.*).  Deliver every one of
                # them, however many arrive in a row, before expecting
                # the next fragment.
                while _opcode in (WSMsgType.PING, WSMsgType.PONG,
                                  WSMsgType.CLOSE):
                    if _opcode == WSMsgType.CLOSE:
                        out.feed_data(
                            _parse_close_frame(fin, _opcode, payload), 0)
                    elif _opcode == WSMsgType.PING:
                        out.feed_data(
                            WSMessage(WSMsgType.PING, payload, ''),
                            len(payload))
                    else:
                        out.feed_data(
                            WSMessage(WSMsgType.PONG, payload, ''),
                            len(payload))
                    fin, _opcode, payload = yield from parse_frame(buf, True)

                if _opcode != WSMsgType.CONTINUATION:
                    raise WebSocketError(
                        WSCloseCode.PROTOCOL_ERROR,
                        'The opcode in non-fin frame is expected '
                        'to be zero, got {!r}'.format(_opcode))
                data.append(payload)

            if opcode == WSMsgType.TEXT:
                try:
                    text = b''.join(data).decode('utf-8')
                except UnicodeDecodeError as exc:
                    raise WebSocketError(
                        WSCloseCode.INVALID_TEXT,
                        'Invalid UTF-8 text message') from exc
                out.feed_data(WSMessage(WSMsgType.TEXT, text, ''), len(text))
            else:
                blob = b''.join(data)
                out.feed_data(
                    WSMessage(WSMsgType.BINARY, blob, ''), len(blob))
|
||
|
||
|
||
native_byteorder = sys.byteorder
|
||
|
||
|
||
def _websocket_mask_python(mask, data):
|
||
"""Websocket masking function.
|
||
|
||
`mask` is a `bytes` object of length 4; `data` is a `bytes` object
|
||
of any length. Returns a `bytes` object of the same length as
|
||
`data` with the mask applied as specified in section 5.3 of RFC
|
||
6455.
|
||
|
||
This pure-python implementation may be replaced by an optimized
|
||
version when available.
|
||
|
||
"""
|
||
assert isinstance(data, bytearray), data
|
||
assert len(mask) == 4, mask
|
||
datalen = len(data)
|
||
if datalen == 0:
|
||
# everything work without this, but may be changed later in Python.
|
||
return bytearray()
|
||
data = int.from_bytes(data, native_byteorder)
|
||
mask = int.from_bytes(mask * (datalen // 4) + mask[: datalen % 4],
|
||
native_byteorder)
|
||
return (data ^ mask).to_bytes(datalen, native_byteorder)
|
||
|
||
|
||
# Pick the masking implementation.  The pure-Python version is the
# default; the compiled one replaces it unless AIOHTTP_NO_EXTENSIONS is
# set to a non-empty value or the extension module cannot be imported.
_websocket_mask = _websocket_mask_python
if not os.environ.get('AIOHTTP_NO_EXTENSIONS'):
    try:
        from ._websocket import _websocket_mask_cython
    except ImportError:  # pragma: no cover
        pass
    else:
        _websocket_mask = _websocket_mask_cython
|
||
|
||
|
||
def parse_frame(buf, continuation=False):
    """Read one frame from *buf* and return ``(fin, opcode, payload)``.

    This is a generator; every read is delegated with
    ``yield from buf.read(n)``.

    *continuation* should be true when the caller is in the middle of a
    fragmented message; only then is a non-final CONTINUATION frame
    accepted.

    ``fin`` is ``0`` or ``1``, ``opcode`` is the low four bits of the
    first header byte, and ``payload`` is the frame body (unmasked when
    the frame carried a mask) or an empty ``bytearray``.

    Raises ``WebSocketError`` with ``WSCloseCode.PROTOCOL_ERROR`` for
    non-zero reserved bits, a fragmented or oversized control frame, or
    an unexpected non-final CONTINUATION frame.
    """
    # Fixed two-byte header: FIN/RSV1-3/opcode, then MASK/payload length.
    header = yield from buf.read(2)
    first_byte, second_byte = header

    fin = (first_byte >> 7) & 1
    opcode = first_byte & 0xf
    # Opcodes 0x8-0xf are control frames.
    is_control = opcode > 0x7

    # frame-fin           = %x0 ; more frames of this message follow
    #                     / %x1 ; final frame of this message
    # frame-rsv1..3       = %x0 ; 1 bit each, MUST be 0 unless negotiated
    #                             otherwise (bits 0x40, 0x20 and 0x10)
    if first_byte & 0x70:
        raise WebSocketError(
            WSCloseCode.PROTOCOL_ERROR,
            'Received frame with non-zero reserved bits')

    if is_control and fin == 0:
        raise WebSocketError(
            WSCloseCode.PROTOCOL_ERROR,
            'Received fragmented control frame')

    # NOTE(review): the error text says "non-zero opcode", yet the
    # condition fires for a CONTINUATION (zero) opcode arriving while no
    # fragmented message is in progress.
    if fin == 0 and opcode == WSMsgType.CONTINUATION and not continuation:
        raise WebSocketError(
            WSCloseCode.PROTOCOL_ERROR,
            'Received new fragment frame with non-zero '
            'opcode {!r}'.format(opcode))

    has_mask = second_byte & 0x80
    length = second_byte & 0x7f

    # Control frames MUST have a payload length of 125 bytes or less
    if is_control and length > 125:
        raise WebSocketError(
            WSCloseCode.PROTOCOL_ERROR,
            "Control frame payload cannot be larger than 125 bytes")

    # 126 and 127 are markers: the real length follows as a 16-bit or
    # 64-bit unsigned integer.
    if length == 126:
        extended = yield from buf.read(2)
        length = UNPACK_LEN2(extended)[0]
    elif length == 127:
        extended = yield from buf.read(8)
        length = UNPACK_LEN3(extended)[0]

    # The 4-byte masking key, when present, precedes the payload.
    if has_mask:
        mask = yield from buf.read(4)

    # read payload
    if length:
        payload = yield from buf.read(length)
    else:
        payload = bytearray()

    if has_mask:
        payload = _websocket_mask(bytes(mask), payload)

    return fin, opcode, payload
|
||
|
||
|
||
class WebSocketWriter:
    """Serialize WebSocket frames and pass them to ``writer.write``."""

    def __init__(self, writer, *, use_mask=False, random=random.Random()):
        """Store the output *writer* and the masking configuration.

        When *use_mask* is true every frame is masked with a key drawn
        from ``random.randrange``.  The default *random* object is
        created once, when the class is defined, and shared by all
        writers that do not pass their own.
        """
        self.writer = writer
        self.use_mask = use_mask
        self.randrange = random.randrange

    def _send_frame(self, message, opcode):
        """Send a frame over the websocket with message as its payload."""
        size = len(message)
        masked = self.use_mask
        mask_bit = 0x80 if masked else 0
        # FIN is always set: each call emits one complete, final frame.
        first_byte = 0x80 | opcode

        # Choose the shortest length encoding that fits the payload.
        if size < 126:
            header = PACK_LEN1(first_byte, size | mask_bit)
        elif size < (1 << 16):
            header = PACK_LEN2(first_byte, 126 | mask_bit, size)
        else:
            header = PACK_LEN3(first_byte, 127 | mask_bit, size)

        if masked:
            # NOTE: randrange excludes its upper bound, so the key
            # 0xffffffff is never generated.
            key = self.randrange(0, 0xffffffff).to_bytes(4, 'big')
            body = _websocket_mask(key, bytearray(message))
            self.writer.write(header + key + body)
            return

        if size > MSG_SIZE:
            # Write header and payload separately rather than
            # concatenating them into one new bytes object.
            self.writer.write(header)
            self.writer.write(message)
        else:
            self.writer.write(header + message)

    def pong(self, message=b''):
        """Send pong message."""
        payload = message.encode('utf-8') if isinstance(message, str) \
            else message
        self._send_frame(payload, WSMsgType.PONG)

    def ping(self, message=b''):
        """Send ping message."""
        payload = message.encode('utf-8') if isinstance(message, str) \
            else message
        self._send_frame(payload, WSMsgType.PING)

    def send(self, message, binary=False):
        """Send a frame over the websocket with message as its payload.

        ``str`` messages are UTF-8 encoded first.  The frame is BINARY
        when *binary* is true, TEXT otherwise.
        """
        payload = message.encode('utf-8') if isinstance(message, str) \
            else message
        frame_type = WSMsgType.BINARY if binary else WSMsgType.TEXT
        self._send_frame(payload, frame_type)

    def close(self, code=1000, message=b''):
        """Close the websocket, sending the specified code and message."""
        reason = message.encode('utf-8') if isinstance(message, str) \
            else message
        self._send_frame(
            PACK_CLOSE_CODE(code) + reason, opcode=WSMsgType.CLOSE)
|
||
|
||
|
||
def do_handshake(method, headers, transport, protocols=()):
    """Prepare WebSocket handshake.

    It returns a 5-tuple: HTTP response code (``101``), response
    headers, websocket parser, websocket writer and the negotiated
    sub-protocol (``None`` when there is none).  It does not perform
    any IO.

    `protocols` is a sequence of known protocols. On successful handshake,
    the returned response headers contain the first protocol requested by
    the client which the server also knows.

    Raises ``errors.HttpProcessingError`` (405) when *method* is not GET,
    and ``errors.HttpBadRequest`` for a missing or wrong Upgrade or
    Connection header, an unsupported version, or an invalid
    Sec-WebSocket-Key.

    """
    # WebSocket accepts only GET
    if method.upper() != hdrs.METH_GET:
        raise errors.HttpProcessingError(
            code=405, headers=((hdrs.ALLOW, hdrs.METH_GET),))

    if 'websocket' != headers.get(hdrs.UPGRADE, '').lower().strip():
        raise errors.HttpBadRequest(
            message='No WebSocket UPGRADE hdr: {}\n Can '
            '"Upgrade" only to "WebSocket".'.format(headers.get(hdrs.UPGRADE)))

    if 'upgrade' not in headers.get(hdrs.CONNECTION, '').lower():
        raise errors.HttpBadRequest(
            message='No CONNECTION upgrade hdr: {}'.format(
                headers.get(hdrs.CONNECTION)))

    # find common sub-protocol between client and server
    protocol = None
    if hdrs.SEC_WEBSOCKET_PROTOCOL in headers:
        req_protocols = [str(proto.strip()) for proto in
                         headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(',')]

        for proto in req_protocols:
            if proto in protocols:
                protocol = proto
                break
        else:
            # No overlap found: Return no protocol as per spec
            ws_logger.warning(
                'Client protocols %r don’t overlap server-known ones %r',
                req_protocols, protocols)

    # check supported version
    version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, '')
    if version not in ('13', '8', '7'):
        raise errors.HttpBadRequest(
            message='Unsupported version: {}'.format(version),
            headers=((hdrs.SEC_WEBSOCKET_VERSION, '13'),))

    # check client handshake for validity: the key must be the base64
    # encoding of exactly 16 bytes
    key = headers.get(hdrs.SEC_WEBSOCKET_KEY)
    try:
        key_is_valid = bool(key) and len(base64.b64decode(key)) == 16
    except (binascii.Error, ValueError):
        # binascii.Error covers malformed base64; b64decode raises a
        # plain ValueError for a str containing non-ASCII characters.
        key_is_valid = False
    if not key_is_valid:
        raise errors.HttpBadRequest(
            message='Handshake error: {!r}'.format(key))

    accept = base64.b64encode(
        hashlib.sha1(key.encode() + WS_KEY).digest()).decode()
    response_headers = [
        (hdrs.UPGRADE, 'websocket'),
        (hdrs.CONNECTION, 'upgrade'),
        (hdrs.TRANSFER_ENCODING, 'chunked'),
        (hdrs.SEC_WEBSOCKET_ACCEPT, accept)]

    if protocol:
        response_headers.append((hdrs.SEC_WEBSOCKET_PROTOCOL, protocol))

    # response code, headers, parser, writer, protocol
    return (101,
            response_headers,
            WebSocketParser,
            WebSocketWriter(transport),
            protocol)
|