147 lines
4.1 KiB
Python
147 lines
4.1 KiB
Python
|
|
#
|
||
|
|
# Copyright BitBake Contributors
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: GPL-2.0-only
|
||
|
|
#
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import itertools
|
||
|
|
import json
|
||
|
|
from datetime import datetime
|
||
|
|
from .exceptions import ClientError, ConnectionClosedError
|
||
|
|
|
||
|
|
|
||
|
|
# Largest line (in characters, including its newline) that StreamConnection
# emits by default. The Python async server defaults to a 64K receive buffer,
# so the chunk size is hard-coded to half of that. Having client and server
# advertise their limits to each other would be more flexible, but it adds a
# round trip to connection setup, so it is only worth doing if it turns out
# to be necessary.
DEFAULT_MAX_CHUNK = 32768  # 32 KiB
|
||
|
|
|
||
|
|
|
||
|
|
def chunkify(msg, max_chunk):
    """Split a serialized message into newline-terminated lines for sending.

    A message shorter than ``max_chunk - 1`` characters is yielded as one
    line. Longer messages are sent as a chunk stream in three parts:
    a ``{"chunk-stream": null}`` header line; the message in slices of at
    most ``max_chunk - 1`` characters, each followed by a newline; and an
    empty line that marks the end of the stream.

    :param msg: the message text. It is assumed to contain no newlines, since
        the receiver splits on them (json.dumps without ``indent`` never
        emits raw newlines).
    :param max_chunk: maximum length of each data line, newline included.
        Must be at least 2.
    :raises ValueError: if *max_chunk* is less than 2. In that case no chunk
        could carry any data, and the message would be dropped silently.
        Because this is a generator, the error surfaces on first iteration.
    """
    if max_chunk < 2:
        raise ValueError("max_chunk must be at least 2, got %r" % (max_chunk,))

    # Leave room for the newline that terminates every line.
    step = max_chunk - 1

    if len(msg) < step:
        yield msg + "\n"
        return

    yield json.dumps({"chunk-stream": None}) + "\n"
    for start in range(0, len(msg), step):
        yield msg[start:start + step] + "\n"
    # An empty line terminates the chunk stream.
    yield "\n"
|
||
|
|
|
||
|
|
|
||
|
|
def json_serialize(obj):
    """``default=`` hook for json.dumps that encodes datetimes.

    Returns ``obj.isoformat()`` for datetime instances. Any other type raises
    TypeError, which json.dumps propagates to its caller.
    """
    if not isinstance(obj, datetime):
        raise TypeError("Type %s not serializeable" % type(obj))
    return obj.isoformat()
|
||
|
|
|
||
|
|
|
||
|
|
class StreamConnection(object):
    """JSON message connection over a newline-delimited reader/writer pair.

    Each message is serialized with json.dumps and sent as lines produced by
    chunkify. Short messages occupy a single line. Long messages are sent as
    a ``{"chunk-stream": null}`` header line, then the JSON text split across
    lines, then an empty line.
    """

    def __init__(self, reader, writer, timeout, max_chunk=DEFAULT_MAX_CHUNK):
        # reader/writer: presumably an asyncio StreamReader/StreamWriter pair;
        # readline(), write(), drain(), close() and get_extra_info() are used.
        # timeout: seconds to wait for each received line; negative = forever.
        # max_chunk: longest line send_message() will emit, newline included.
        self.reader = reader
        self.writer = writer
        self.timeout = timeout
        self.max_chunk = max_chunk

    @property
    def address(self):
        """The writer's "peername" extra info (the remote address)."""
        return self.writer.get_extra_info("peername")

    async def send_message(self, msg):
        """Serialize *msg* to JSON and send it, chunked if it is too long."""
        for c in chunkify(json.dumps(msg, default=json_serialize), self.max_chunk):
            self.writer.write(c.encode("utf-8"))
        await self.writer.drain()

    async def recv_message(self):
        """Receive one JSON message, reassembling it if it was chunked.

        Falsy decoded values are returned immediately.

        Raises:
            ConnectionError: on a timeout or a line without a trailing newline.
            ConnectionClosedError: if the stream ends.
            ValueError: from json.loads if the payload is not valid JSON.
        """
        l = await self.recv()

        m = json.loads(l)
        if not m:
            return m

        # Only a dict can be a chunk-stream header. Applied to other JSON
        # values, "in" would do a substring/list search (e.g. the message
        # "chunk-stream") or raise TypeError (numbers).
        if isinstance(m, dict) and "chunk-stream" in m:
            lines = []
            while True:
                # Chunk lines are raw slices of the JSON text, so only the
                # newline may be removed. recv()'s rstrip() would drop spaces
                # inside a JSON string that fall at a chunk boundary, and
                # would end the stream early on an all-space chunk.
                l = await self._readline()
                if not l:
                    break
                lines.append(l)

            m = json.loads("".join(lines))

        return m

    async def send(self, msg):
        """Send *msg* formatted as a string, followed by a newline."""
        self.writer.write(("%s\n" % msg).encode("utf-8"))
        await self.writer.drain()

    async def _readline(self):
        """Read one line and return it decoded, minus only its newline.

        Raises:
            ConnectionError: on a timeout, or if the line does not end with
                a newline (e.g. a partial line at end of stream).
            ConnectionClosedError: if the stream is at EOF.
        """
        if self.timeout < 0:
            line = await self.reader.readline()
        else:
            try:
                line = await asyncio.wait_for(self.reader.readline(), self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("Timed out waiting for data")

        if not line:
            raise ConnectionClosedError("Connection closed")

        line = line.decode("utf-8")

        if not line.endswith("\n"):
            raise ConnectionError("Bad message %r" % (line))

        return line[:-1]

    async def recv(self):
        """Read one line and return it with all trailing whitespace stripped.

        Raises the same exceptions as _readline().
        """
        line = await self._readline()
        return line.rstrip()

    async def close(self):
        """Drop the reader and close the writer; safe to call repeatedly."""
        self.reader = None
        if self.writer is not None:
            self.writer.close()
            self.writer = None
|
||
|
|
|
||
|
|
|
||
|
|
class WebsocketConnection(object):
    """JSON message connection carried over a websocket.

    Every message travels as one websocket message, so no chunking is needed.
    The ``websockets`` package is imported lazily, only when sending or
    receiving, so it is required only if this transport is used.
    """

    def __init__(self, socket, timeout):
        # socket: a connected websocket exposing send(), recv(), close() and
        # remote_address (presumably from the ``websockets`` package).
        # timeout: seconds recv() waits for a message; negative = forever.
        self.socket = socket
        self.timeout = timeout

    @property
    def address(self):
        """The peer's address components joined with ':'."""
        parts = [str(part) for part in self.socket.remote_address]
        return ":".join(parts)

    async def send_message(self, msg):
        """Serialize *msg* to JSON and send it as one websocket message."""
        payload = json.dumps(msg, default=json_serialize)
        await self.send(payload)

    async def recv_message(self):
        """Receive one websocket message and decode it from JSON."""
        payload = await self.recv()
        return json.loads(payload)

    async def send(self, msg):
        """Send *msg* as-is; a closed socket raises ConnectionClosedError."""
        from websockets.exceptions import ConnectionClosed

        try:
            await self.socket.send(msg)
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")

    async def recv(self):
        """Receive one raw websocket message.

        Raises ConnectionError if the timeout expires and
        ConnectionClosedError if the websocket is closed.
        """
        from websockets.exceptions import ConnectionClosed

        try:
            if self.timeout < 0:
                frame = await self.socket.recv()
            else:
                try:
                    frame = await asyncio.wait_for(self.socket.recv(), self.timeout)
                except asyncio.TimeoutError:
                    raise ConnectionError("Timed out waiting for data")
        except ConnectionClosed:
            raise ConnectionClosedError("Connection closed")
        return frame

    async def close(self):
        """Close the websocket if still open, then forget it."""
        sock = self.socket
        if sock is None:
            return
        await sock.close()
        self.socket = None
|