Import of the watch repository from Pebble

This commit is contained in:
Matthieu Jeanson 2024-12-12 16:43:03 -08:00 committed by Katharine Berry
commit 3b92768480
10334 changed files with 2564465 additions and 0 deletions

89
python_libs/pulse2/.gitignore vendored Normal file
View file

@ -0,0 +1,89 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# IPython Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# dotenv
.env
# virtualenv
venv/
ENV/
# Spyder project settings
.spyderproject
# Rope project settings
.ropeproject

View file

@ -0,0 +1,6 @@
pebble.pulse2
=============
pulse2 is a Python implementation of the PULSEv2 protocol suite.
https://pebbletechnology.atlassian.net/wiki/display/DEV/PULSEv2+Protocol+Suite

View file

@ -0,0 +1,15 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__import__('pkg_resources').declare_namespace(__name__)

View file

@ -0,0 +1,24 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from . import link, transports
# Public aliases for the classes that users will interact with directly.
from .link import Interface
link.Link.register_transport(
'best-effort', transports.BestEffortApplicationTransport)
link.Link.register_transport('reliable', transports.ReliableTransport)

View file

@ -0,0 +1,37 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class PulseException(Exception):
pass
class TTYAutodetectionUnavailable(PulseException):
pass
class ReceiveQueueEmpty(PulseException):
pass
class TransportNotReady(PulseException):
pass
class SocketClosed(PulseException):
pass
class AlreadyInProgressError(PulseException):
'''Another operation is already in progress.
'''

View file

@ -0,0 +1,148 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
PULSEv2 Framing
This module handles encoding and decoding of datagrams in PULSEv2 frames: flag
delimiters, transparency encoding and Frame Check Sequence. The content of the
datagrams themselves are not examined or parsed.
'''
from __future__ import absolute_import
import binascii
import struct
try:
import queue
except ImportError: # Py2
import Queue as queue
from cobs import cobs
FLAG = 0x55
CRC32_RESIDUE = binascii.crc32(b'\0' * 4)
class FramingException(Exception):
pass
class DecodeError(FramingException):
pass
class CorruptFrame(FramingException):
pass
class FrameSplitter(object):
'''Takes a byte stream and partitions it into frames.
Empty frames (two consecutive flag bytes) are silently discarded.
No transparency conversion is applied to the contents of the frames.
FrameSplitter objects support iteration for retrieving split frames.
>>> splitter = FrameSplitter()
>>> splitter.write(b'\x55foo\x55bar\x55')
>>> list(splitter)
[b'foo', b'bar']
'''
def __init__(self, max_frame_length=0):
self.frames = queue.Queue()
self.input_buffer = bytearray()
self.max_frame_length = max_frame_length
self.waiting_for_sync = True
def write(self, data):
'''Write bytes into the splitter for processing.
'''
for char in bytearray(data):
if self.waiting_for_sync:
if char == FLAG:
self.waiting_for_sync = False
else:
if char == FLAG:
if self.input_buffer:
self.frames.put_nowait(bytes(self.input_buffer))
self.input_buffer = bytearray()
else:
if (not self.max_frame_length or
len(self.input_buffer) < self.max_frame_length):
self.input_buffer.append(char)
else:
self.input_buffer = bytearray()
self.waiting_for_sync = True
def __iter__(self):
while True:
try:
yield self.frames.get_nowait()
except queue.Empty:
return
def decode_transparency(frame_bytes):
'''Decode the transparency encoding applied to a PULSEv2 frame.
Returns the decoded frame, or raises `DecodeError`.
'''
frame_bytes = bytearray(frame_bytes)
if FLAG in frame_bytes:
raise DecodeError("flag byte in encoded frame")
try:
return cobs.decode(bytes(frame_bytes.replace(b'\0', bytearray([FLAG]))))
except cobs.DecodeError as e:
raise DecodeError(str(e))
def strip_fcs(frame_bytes):
'''Validates the FCS in a PULSEv2 frame.
The frame is returned with the FCS removed if the FCS check passes.
A `CorruptFrame` exception is raised if the FCS check fails.
The frame must not be transparency-encoded.
'''
if len(frame_bytes) <= 4:
raise CorruptFrame('frame too short')
if binascii.crc32(frame_bytes) != CRC32_RESIDUE:
raise CorruptFrame('FCS check failure')
return frame_bytes[:-4]
def decode_frame(frame_bytes):
'''Decode and validate a PULSEv2-encoded frame.
Returns the datagram extracted from the frame, or raises a
`FramingException` or subclass if there was an error decoding the frame.
'''
return strip_fcs(decode_transparency(frame_bytes))
def encode_frame(datagram):
'''Encode a datagram in a PULSEv2 frame.
'''
datagram = bytearray(datagram)
fcs = binascii.crc32(datagram) & 0xffffffff
fcs_bytes = struct.pack('<I', fcs)
datagram.extend(fcs_bytes)
flag = bytearray([FLAG])
frame = cobs.encode(bytes(datagram)).replace(flag, b'\0')
return flag + frame + flag

View file

@ -0,0 +1,314 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
import threading
import serial
from . import exceptions, framing, ppp, transports
from . import logging as pulse2_logging
from . import pcap_file
try:
import pyftdi.serialext
except ImportError:
pass
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
DBGSERIAL_PORT_SETTINGS = dict(baudrate=1000000, timeout=0.1,
interCharTimeout=0.0001)
def get_dbgserial_tty():
# Local import so that we only depend on this package if we're attempting
# to autodetect the TTY. This package isn't always available (e.g., MFG),
# so we don't want it to be required.
try:
import pebble_tty
return pebble_tty.find_dbgserial_tty()
except ImportError:
raise exceptions.TTYAutodetectionUnavailable
class Interface(object):
'''The PULSEv2 lower data-link layer.
An Interface object is roughly analogous to a network interface,
somewhat like an Ethernet port. It provides connectionless service
with PULSEv2 framing, which upper layers build upon to provide
connection-oriented service.
An Interface is bound to an I/O stream, such as a Serial port, and
remains open until either the Interface is explicitly closed or the
underlying I/O stream is closed from underneath it.
'''
def __init__(self, iostream, capture_stream=None):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.iostream = iostream
self.closed = False
self.close_lock = threading.RLock()
self.default_packet_handler_cb = None
self.sockets = {}
self.pcap = None
if capture_stream:
self.pcap = pcap_file.PcapWriter(
capture_stream, pcap_file.LINKTYPE_PPP_WITH_DIR)
self.receive_thread = threading.Thread(target=self.receive_loop)
self.receive_thread.daemon = True
self.receive_thread.start()
self.simplex_transport = transports.SimplexTransport(self)
self._link = None
self.link_available = threading.Event()
self.lcp = ppp.LinkControlProtocol(self)
self.lcp.on_link_up = self.on_link_up
self.lcp.on_link_down = self.on_link_down
self.lcp.up()
self.lcp.open()
@classmethod
def open_dbgserial(cls, url=None, capture_stream=None):
if url is None:
url = get_dbgserial_tty()
elif url == 'qemu':
url = 'socket://localhost:12345'
ser = serial.serial_for_url(url, **DBGSERIAL_PORT_SETTINGS)
if url.startswith('socket://'):
# interCharTimeout doesn't apply to sockets, so shrink the receive
# timeout to compensate.
ser.timeout = 0.5
ser._socket.settimeout(0.5)
return cls(ser, capture_stream)
def connect(self, protocol):
'''Open a link-layer socket for sending and receiving packets
of a specific protocol number.
'''
if protocol in self.sockets and not self.sockets[protocol].closed:
raise ValueError('A socket is already bound '
'to protocol 0x%04x' % protocol)
self.sockets[protocol] = socket = InterfaceSocket(self, protocol)
return socket
def unregister_socket(self, protocol):
'''Used by InterfaceSocket objets to unregister themselves when
closing.
'''
try:
del self.sockets[protocol]
except KeyError:
pass
def receive_loop(self):
splitter = framing.FrameSplitter()
while True:
if self.closed:
self.logger.info('Interface closed; receive loop exiting')
break
try:
splitter.write(self.iostream.read(1))
except IOError:
if self.closed:
self.logger.info('Interface closed; receive loop exiting')
else:
self.logger.exception('Unexpected error while reading '
'from iostream')
self._down()
break
for frame in splitter:
try:
datagram = framing.decode_frame(frame)
if self.pcap:
# Prepend pseudo-header meaning "received by this host"
self.pcap.write_packet(b'\0' + datagram)
protocol, information = ppp.unencapsulate(datagram)
if protocol in self.sockets:
self.sockets[protocol].handle_packet(information)
else:
# TODO LCP Protocol-Reject
self.logger.info('Protocol-reject: %04X', protocol)
except (framing.DecodeError, framing.CorruptFrame):
pass
def send_packet(self, protocol, packet):
if self.closed:
raise ValueError('I/O operation on closed interface')
datagram = ppp.encapsulate(protocol, packet)
if self.pcap:
# Prepend pseudo-header meaning "sent by this host"
self.pcap.write_packet(b'\x01' + datagram)
self.iostream.write(framing.encode_frame(datagram))
def close_all_sockets(self):
# Iterating over a copy of sockets since socket.close() can call
# unregister_socket, which modifies the socket dict. Modifying
# a dict during iteration is not allowed, so the iteration is
# completed (by making the copy) before modification can begin.
for socket in list(self.sockets.values()):
socket.close()
def close(self):
with self.close_lock:
if self.closed:
return
self.lcp.shutdown()
self.close_all_sockets()
self._down()
if self.pcap:
self.pcap.close()
def _down(self):
'''The lower layer (iostream) is down. Bring down the interface.
'''
with self.close_lock:
self.closed = True
self.close_all_sockets()
self.lcp.down()
self.simplex_transport.down()
self.iostream.close()
def on_link_up(self):
# FIXME PBL-34320 proper MTU/MRU support
self._link = Link(self, mtu=1500)
# Test whether the link is ready to carry traffic
self.lcp.ping(self._ping_done)
def _ping_done(self, ping_check_succeeded):
if ping_check_succeeded:
self.link_available.set()
else:
self.lcp.restart()
def on_link_down(self):
self.link_available.clear()
self._link.down()
self._link = None
def get_link(self, timeout=60.0):
'''Get the opened Link object for this interface.
This function will block waiting for the Link to be available.
It will return `None` if the timeout expires before the link
is available.
'''
if self.closed:
raise ValueError('No link available on closed interface')
if self.link_available.wait(timeout):
assert self._link is not None
return self._link
class InterfaceSocket(object):
'''A socket for sending and receiving link-layer packets over a
PULSE2 interface.
Callbacks can be registered on the socket by assigning callables to
the appropriate attributes on the socket object. Callbacks can be
unregistered by setting the attributes back to `None`.
Available callbacks:
- `on_packet(information)`
- `on_close()`
'''
on_packet = None
on_close = None
def __init__(self, interface, protocol):
self.interface = interface
self.protocol = protocol
self.closed = False
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
self.close()
def send(self, information):
if self.closed:
raise exceptions.SocketClosed('I/O operation on closed socket')
self.interface.send_packet(self.protocol, information)
def handle_packet(self, information):
if self.on_packet and not self.closed:
self.on_packet(information)
def close(self):
if self.closed:
return
self.closed = True
if self.on_close:
self.on_close()
self.interface.unregister_socket(self.protocol)
self.on_packet = None
self.on_close = None
class Link(object):
'''The connectionful portion of a PULSE2 interface.
'''
TRANSPORTS = {}
on_close = None
@classmethod
def register_transport(cls, name, factory):
'''Register a PULSE transport.
'''
if name in cls.TRANSPORTS:
raise ValueError('transport name %r is already registered '
'with %r' % (name, cls.TRANSPORTS[name]))
cls.TRANSPORTS[name] = factory
def __init__(self, interface, mtu):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.interface = interface
self.closed = False
self.mtu = mtu
self.transports = {}
for name, factory in self.TRANSPORTS.iteritems():
transport = factory(interface, mtu)
self.transports[name] = transport
def open_socket(self, transport, port, timeout=30.0):
if self.closed:
raise ValueError('Cannot open socket on closed Link')
if transport not in self.transports:
raise KeyError('Unknown transport %r' % transport)
return self.transports[transport].open_socket(port, timeout)
def down(self):
self.closed = True
if self.on_close:
self.on_close()
for transport in self.transports.itervalues():
transport.down()

View file

@ -0,0 +1,31 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
class TaggedAdapter(logging.LoggerAdapter):
'''Annotates all log messages with a "[tag]" prefix.
The value of the tag is specified in the dict argument passed into
the adapter's constructor.
>>> logger = logging.getLogger(__name__)
>>> adapter = TaggedAdapter(logger, {'tag': 'tag value'})
'''
def process(self, msg, kwargs):
return '[%s] %s' % (self.extra['tag'], msg), kwargs

View file

@ -0,0 +1,68 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Writer for Libpcap capture files
https://wiki.wireshark.org/Development/LibpcapFileFormat
'''
from __future__ import absolute_import
import struct
import threading
import time
LINKTYPE_PPP_WITH_DIR = 204
class PcapWriter(object):
def __init__(self, outfile, linktype):
self.lock = threading.Lock()
self.outfile = outfile
self._write_pcap_header(linktype)
def close(self):
with self.lock:
self.outfile.close()
def __enter__(self):
return self
def __exit__(self, *args):
self.close()
def _write_pcap_header(self, linktype):
header = struct.pack('!IHHiIII',
0xa1b2c3d4, # guint32 magic_number
2, # guint16 version_major
4, # guint16 version_minor
0, # guint32 thiszone
0, # guint32 sigfigs (unused)
65535, # guint32 snaplen
linktype) # guint32 network
self.outfile.write(header)
def write_packet(self, data, timestamp=None, orig_len=None):
assert len(data) <= 65535
if timestamp is None:
timestamp = time.time()
if orig_len is None:
orig_len = len(data)
ts_seconds = int(timestamp)
ts_usec = int((timestamp - ts_seconds) * 1000000)
header = struct.pack('!IIII', ts_seconds, ts_usec, len(data), orig_len)
with self.lock:
self.outfile.write(header + data)

View file

@ -0,0 +1,201 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''PULSE Control Message Protocol
'''
from __future__ import absolute_import
import codecs
import collections
import enum
import logging
import struct
import threading
from . import exceptions
from . import logging as pulse2_logging
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class ParseError(exceptions.PulseException):
pass
@enum.unique
class PCMPCode(enum.Enum):
Echo_Request = 1
Echo_Reply = 2
Discard_Request = 3
Port_Closed = 129
Unknown_Code = 130
class PCMPPacket(collections.namedtuple('PCMPPacket', 'code information')):
__slots__ = ()
@classmethod
def parse(cls, packet):
packet = bytes(packet)
if len(packet) < 1:
raise ParseError('packet too short')
return cls(code=struct.unpack('B', packet[0])[0],
information=packet[1:])
@staticmethod
def build(code, information):
return struct.pack('B', code) + bytes(information)
class PulseControlMessageProtocol(object):
'''This protocol is unique in that it is logically part of the
transport but is layered on top of the transport over the wire.
To keep from needing to create a new thread just for reading from
the socket, the implementation acts both like a socket and protocol
all in one.
'''
PORT = 0x0001
on_port_closed = None
@classmethod
def bind(cls, transport):
return transport.open_socket(cls.PORT, factory=cls)
def __init__(self, transport, port):
assert port == self.PORT
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': 'PCMP(%s)' % (type(transport).__name__)})
self.transport = transport
self.closed = False
self.ping_lock = threading.RLock()
self.ping_cb = None
self.ping_attempts_remaining = 0
self.ping_timer = None
def close(self):
if self.closed:
return
with self.ping_lock:
self.ping_cb = None
if self.ping_timer:
self.ping_timer.cancel()
self.closed = True
self.transport.unregister_socket(self.PORT)
def send_unknown_code(self, bad_code):
self.transport.send(self.PORT, PCMPPacket.build(
PCMPCode.Unknown_Code.value, struct.pack('B', bad_code)))
def send_echo_request(self, data):
self.transport.send(self.PORT, PCMPPacket.build(
PCMPCode.Echo_Request.value, data))
def send_echo_reply(self, data):
self.transport.send(self.PORT, PCMPPacket.build(
PCMPCode.Echo_Reply.value, data))
def on_receive(self, raw_packet):
try:
packet = PCMPPacket.parse(raw_packet)
except ParseError:
self.logger.exception('Received malformed packet')
return
try:
code = PCMPCode(packet.code)
except ValueError:
self.logger.error('Received packet with unknown code %d',
packet.code)
self.send_unknown_code(packet.code)
return
if code == PCMPCode.Discard_Request:
pass
elif code == PCMPCode.Echo_Request:
self.send_echo_reply(packet.information)
elif code == PCMPCode.Echo_Reply:
with self.ping_lock:
if self.ping_cb:
self.ping_timer.cancel()
self.ping_cb(True)
self.ping_cb = None
self.logger.debug('Echo-Reply: %s',
codecs.encode(packet.information, 'hex'))
elif code == PCMPCode.Port_Closed:
if len(packet.information) == 2:
if self.on_port_closed:
closed_port, = struct.unpack('!H', packet.information)
self.on_port_closed(closed_port)
else:
self.logger.error(
'Remote peer sent malformed Port-Closed packet: %s',
codecs.encode(packet.information, 'hex'))
elif code == PCMPCode.Unknown_Code:
if len(packet.information) == 1:
self.logger.error('Remote peer sent Unknown-Code(%d) packet',
struct.unpack('B', packet.information)[0])
else:
self.logger.error(
'Remote peer sent malformed Unknown-Code packet: %s',
codecs.encode(packet.information, 'hex'))
else:
assert False, 'Known code not handled'
def ping(self, result_cb, attempts=3, timeout=1.0):
'''Test the link quality by sending Echo-Request packets and
listening for Echo-Reply packets from the remote peer.
The ping is performed asynchronously. The `result_cb` callable
will be called when the ping completes. It will be called with
a single positional argument: a truthy value if the remote peer
responded to the ping, or a falsy value if all ping attempts
timed out.
'''
if attempts < 1:
raise ValueError('attempts must be positive')
if timeout <= 0:
raise ValueError('timeout must be positive')
with self.ping_lock:
if self.ping_cb:
raise exceptions.AlreadyInProgressError(
'another ping is currently in progress')
self.ping_cb = result_cb
self.ping_attempts_remaining = attempts - 1
self.ping_timeout = timeout
self.send_echo_request(b'')
self.ping_timer = threading.Timer(timeout,
self._ping_timer_expired)
self.ping_timer.daemon = True
self.ping_timer.start()
def _ping_timer_expired(self):
with self.ping_lock:
if not self.ping_cb:
# The Echo-Reply packet must have won the race
return
if self.ping_attempts_remaining:
self.ping_attempts_remaining -= 1
self.send_echo_request(b'')
self.ping_timer = threading.Timer(self.ping_timeout,
self._ping_timer_expired)
self.ping_timer.daemon = True
self.ping_timer.start()
else:
self.ping_cb(False)
self.ping_cb = None

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,68 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division
import math
class OnlineStatistics(object):
'''Calculates various statistical properties of a data series
iteratively, without keeping the data items in memory.
Available statistics:
- Count
- Min
- Max
- Mean
- Variance
- Standard Deviation
The variance calculation algorithm is taken from
https://en.wikipedia.org/w/index.php?title=Algorithms_for_calculating_variance&oldid=715886413#Online_algorithm
'''
def __init__(self):
self.count = 0
self.min = float('nan')
self.max = float('nan')
self.mean = 0.0
self.M2 = 0.0
def update(self, datum):
self.count += 1
if self.count == 1:
self.min = datum
self.max = datum
else:
self.min = min(self.min, datum)
self.max = max(self.max, datum)
delta = datum - self.mean
self.mean += delta / self.count
self.M2 += delta * (datum - self.mean)
@property
def variance(self):
if self.count < 2:
return float('nan')
return self.M2 / (self.count - 1)
@property
def stddev(self):
return math.sqrt(self.variance)
def __str__(self):
return 'min/avg/max/stddev = {:.03f}/{:.03f}/{:.03f}/{:.03f}'.format(
self.min, self.mean, self.max, self.stddev)

View file

@ -0,0 +1,628 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
import threading
import time
try:
import queue
except ImportError:
import Queue as queue
import construct
from . import exceptions
from . import logging as pulse2_logging
from . import pcmp
from . import ppp
from . import stats
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class Socket(object):
'''A socket for sending and receiving packets over a single port
of a PULSE2 transport.
'''
def __init__(self, transport, port):
self.transport = transport
self.port = port
self.closed = False
self.receive_queue = queue.Queue()
def on_receive(self, packet):
self.receive_queue.put((True, packet))
def receive(self, block=True, timeout=None):
if self.closed:
raise exceptions.SocketClosed('I/O operation on closed socket')
try:
info_good, info = self.receive_queue.get(block, timeout)
if not info_good:
assert self.closed
raise exceptions.SocketClosed('Socket closed during receive')
return info
except queue.Empty:
raise exceptions.ReceiveQueueEmpty
def send(self, information):
if self.closed:
raise exceptions.SocketClosed('I/O operation on closed socket')
self.transport.send(self.port, information)
def close(self):
if self.closed:
return
self.closed = True
self.transport.unregister_socket(self.port)
# Wake up the thread blocking on a receive (if any) so that it
# can abort the receive quickly.
self.receive_queue.put((False, None))
@property
def mtu(self):
return self.transport.mtu
class TransportControlProtocol(ppp.ControlProtocol):
def __init__(self, interface, transport, ncp_protocol, display_name=None):
ppp.ControlProtocol.__init__(self, display_name)
self.interface = interface
self.ncp_protocol = ncp_protocol
self.transport = transport
def up(self):
ppp.ControlProtocol.up(self, self.interface.connect(self.ncp_protocol))
def this_layer_up(self, *args):
self.transport.this_layer_up()
def this_layer_down(self, *args):
self.transport.this_layer_down()
BestEffortPacket = construct.Struct('BestEffortPacket', # noqa
construct.UBInt16('port'),
construct.UBInt16('length'),
construct.Field('information', lambda ctx: ctx.length - 4),
ppp.OptionalGreedyString('padding'),
)
class BestEffortTransportBase(object):
def __init__(self, interface, link_mtu):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.sockets = {}
self.closed = False
self._mtu = link_mtu - 4
self.link_socket = interface.connect(self.PROTOCOL_NUMBER)
self.link_socket.on_packet = self.packet_received
def send(self, port, information):
if len(information) > self.mtu:
raise ValueError('Packet length (%d) exceeds transport MTU (%d)' % (
len(information), self.mtu))
packet = BestEffortPacket.build(construct.Container(
port=port, length=len(information)+4,
information=information, padding=b''))
self.link_socket.send(packet)
def packet_received(self, packet):
if self.closed:
self.logger.warning('Received packet on closed transport')
return
try:
fields = BestEffortPacket.parse(packet)
except (construct.ConstructError, ValueError):
self.logger.exception('Received malformed packet')
return
if len(fields.information) + 4 != fields.length:
self.logger.error('Received truncated or corrupt packet '
'(expected %d, got %d data bytes)',
fields.length-4, len(fields.information))
return
if fields.port in self.sockets:
self.sockets[fields.port].on_receive(fields.information)
else:
self.logger.warning('Received packet for unopened port %04X',
fields.port)
def open_socket(self, port, factory=Socket):
if self.closed:
raise ValueError('Cannot open socket on closed transport')
if port in self.sockets and not self.sockets[port].closed:
raise KeyError('Another socket is already opened '
'on port 0x%04x' % port)
socket = factory(self, port)
self.sockets[port] = socket
return socket
def unregister_socket(self, port):
del self.sockets[port]
def down(self):
'''Called by the Link when the link layer goes down.
This closes the Transport object. Once closed, the Transport
cannot be reopened.
'''
self.closed = True
self.close_all_sockets()
self.link_socket.close()
def close_all_sockets(self):
# A socket could try to unregister itself when closing, which
# would modify the sockets dict. Make a copy of the sockets
# collection before closing them so that we are not iterating
# over the dict when it could get modified.
for socket in list(self.sockets.values()):
socket.close()
self.sockets = {}
@property
def mtu(self):
return self._mtu
class BestEffortApplicationTransport(BestEffortTransportBase):
NCP_PROTOCOL_NUMBER = 0xBA29
PROTOCOL_NUMBER = 0x3A29
def __init__(self, interface, link_mtu):
BestEffortTransportBase.__init__(self, interface=interface,
link_mtu=link_mtu)
self.opened = threading.Event()
self.ncp = TransportControlProtocol(
interface=interface, transport=self,
ncp_protocol=self.NCP_PROTOCOL_NUMBER,
display_name='BestEffortControlProtocol')
self.ncp.up()
self.ncp.open()
def this_layer_up(self):
# We can't let PCMP bind itself using the public open_socket
# method as the method will block until self.opened is set, but
# it won't be set until we use PCMP Echo to test that the
# transport is ready to carry traffic. So we must manually bind
# the port without waiting.
self.pcmp = pcmp.PulseControlMessageProtocol(
self, pcmp.PulseControlMessageProtocol.PORT)
self.sockets[pcmp.PulseControlMessageProtocol.PORT] = self.pcmp
self.pcmp.on_port_closed = self.on_port_closed
self.pcmp.ping(self._ping_done)
def _ping_done(self, ping_check_succeeded):
# Don't need to do anything in the success case as receiving
# any packet is enough to set the transport as Opened.
if not ping_check_succeeded:
self.logger.warning('Ping check failed. Restarting transport.')
self.ncp.restart()
def this_layer_down(self):
self.opened.clear()
self.close_all_sockets()
def send(self, *args, **kwargs):
if self.closed:
raise exceptions.TransportNotReady(
'I/O operation on closed transport')
if not self.ncp.is_Opened():
raise exceptions.TransportNotReady(
'I/O operation before transport is opened')
BestEffortTransportBase.send(self, *args, **kwargs)
def packet_received(self, packet):
if self.ncp.is_Opened():
self.opened.set()
BestEffortTransportBase.packet_received(self, packet)
else:
self.logger.warning('Received packet before the transport is open. '
'Discarding.')
def open_socket(self, port, timeout=30.0, factory=Socket):
if not self.opened.wait(timeout):
return None
return BestEffortTransportBase.open_socket(self, port, factory)
def down(self):
self.ncp.down()
BestEffortTransportBase.down(self)
def on_port_closed(self, closed_port):
self.logger.info('Remote peer says port 0x%04X is closed; '
'closing socket', closed_port)
try:
self.sockets[closed_port].close()
except KeyError:
self.logger.exception('No socket is open on port 0x%04X!',
closed_port)
class SimplexTransport(BestEffortTransportBase):
PROTOCOL_NUMBER = 0x5021
def __init__(self, interface):
BestEffortTransportBase.__init__(self, interface=interface, link_mtu=0)
def send(self, *args, **kwargs):
raise NotImplementedError
@property
def mtu(self):
return 0
ReliableInfoPacket = construct.Struct('ReliableInfoPacket', # noqa
# BitStructs are parsed MSBit-first
construct.EmbeddedBitStruct(
construct.BitField('sequence_number', 7), # N(S) in LAPB
construct.Const(construct.Bit('discriminator'), 0),
construct.BitField('ack_number', 7), # N(R) in LAPB
construct.Flag('poll'),
),
construct.UBInt16('port'),
construct.UBInt16('length'),
construct.Field('information', lambda ctx: ctx.length - 6),
ppp.OptionalGreedyString('padding'),
)
ReliableSupervisoryPacket = construct.BitStruct(
'ReliableSupervisoryPacket',
construct.Const(construct.Nibble('reserved'), 0b0000),
construct.Enum(construct.BitField('kind', 2), # noqa
RR=0b00,
RNR=0b01,
REJ=0b10,
),
construct.Const(construct.BitField('discriminator', 2), 0b01),
construct.BitField('ack_number', 7), # N(R) in LAPB
construct.Flag('poll'),
construct.Alias('final', 'poll'),
)
def build_reliable_info_packet(sequence_number, ack_number, poll,
port, information):
return ReliableInfoPacket.build(construct.Container(
sequence_number=sequence_number, ack_number=ack_number, poll=poll,
port=port, information=information, length=len(information)+6,
discriminator=None, padding=b''))
def build_reliable_supervisory_packet(
kind, ack_number, poll=False, final=False):
return ReliableSupervisoryPacket.build(construct.Container(
kind=kind, ack_number=ack_number, poll=poll or final,
final=None, reserved=None, discriminator=None))
class ReliableTransport(object):
'''The reliable transport protocol, also known as TRAIN.
The protocol is based on LAPB from ITU-T Recommendation X.25.
'''
NCP_PROTOCOL_NUMBER = 0xBA33
COMMAND_PROTOCOL_NUMBER = 0x3A33
RESPONSE_PROTOCOL_NUMBER = 0x3A35
MODULUS = 128
max_retransmits = 10 # N2 system parameter in LAPB
retransmit_timeout = 0.2 # T1 system parameter
def __init__(self, interface, link_mtu):
self.logger = pulse2_logging.TaggedAdapter(
logger, {'tag': type(self).__name__})
self.send_queue = queue.Queue()
self.opened = threading.Event()
self.closed = False
self.last_sent_packet = None
# The sequence number of the next in-sequence I-packet to be Tx'ed
self.send_variable = 0 # V(S) in LAPB
self.retransmit_count = 0
self.waiting_for_ack = False
self.last_ack_number = 0 # N(R) of the most recently received packet
self.transmit_lock = threading.RLock()
self.retransmit_timer = None
# The expected sequence number of the next received I-packet
self.receive_variable = 0 # V(R) in LAPB
self.sockets = {}
self._mtu = link_mtu - 6
self.command_socket = interface.connect(
self.COMMAND_PROTOCOL_NUMBER)
self.response_socket = interface.connect(
self.RESPONSE_PROTOCOL_NUMBER)
self.command_socket.on_packet = self.command_packet_received
self.response_socket.on_packet = self.response_packet_received
self.ncp = TransportControlProtocol(
interface=interface, transport=self,
ncp_protocol=self.NCP_PROTOCOL_NUMBER,
display_name='ReliableControlProtocol')
self.ncp.up()
self.ncp.open()
@property
def mtu(self):
return self._mtu
def reset_stats(self):
self.stats = {
'info_packets_sent': 0,
'info_packets_received': 0,
'retransmits': 0,
'out_of_order_packets': 0,
'round_trip_time': stats.OnlineStatistics(),
}
self.last_packet_sent_time = None
def this_layer_up(self):
self.send_variable = 0
self.receive_variable = 0
self.retransmit_count = 0
self.last_ack_number = 0
self.waiting_for_ack = False
self.reset_stats()
# We can't let PCMP bind itself using the public open_socket
# method as the method will block until self.opened is set, but
# it won't be set until the peer sends us a packet over the
# transport. But we want to bind the port without waiting.
self.pcmp = pcmp.PulseControlMessageProtocol(
self, pcmp.PulseControlMessageProtocol.PORT)
self.sockets[pcmp.PulseControlMessageProtocol.PORT] = self.pcmp
self.pcmp.on_port_closed = self.on_port_closed
# Send an RR command packet to elicit an RR response from the
# remote peer. Receiving a response from the peer confirms that
# the transport is ready to carry traffic, at which point we
# will allow applications to start opening sockets.
self.send_supervisory_command(kind='RR', poll=True)
self.start_retransmit_timer()
def this_layer_down(self):
self.opened.clear()
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = None
self.close_all_sockets()
self.logger.info('Info packets sent=%d retransmits=%d',
self.stats['info_packets_sent'],
self.stats['retransmits'])
self.logger.info('Info packets received=%d out-of-order=%d',
self.stats['info_packets_received'],
self.stats['out_of_order_packets'])
self.logger.info('Round-trip %s ms', self.stats['round_trip_time'])
def open_socket(self, port, timeout=30.0, factory=Socket):
if self.closed:
raise ValueError('Cannot open socket on closed transport')
if port in self.sockets and not self.sockets[port].closed:
raise KeyError('Another socket is already opened '
'on port 0x%04x' % port)
if not self.opened.wait(timeout):
return None
socket = factory(self, port)
self.sockets[port] = socket
return socket
def unregister_socket(self, port):
del self.sockets[port]
def down(self):
self.closed = True
self.close_all_sockets()
self.command_socket.close()
self.response_socket.close()
self.ncp.down()
def close_all_sockets(self):
for socket in list(self.sockets.values()):
socket.close()
self.sockets = {}
def on_port_closed(self, closed_port):
self.logger.info('Remote peer says port 0x%04X is closed; '
'closing socket', closed_port)
try:
self.sockets[closed_port].close()
except KeyError:
self.logger.exception('No socket is open on port 0x%04X!',
closed_port)
def _send_info_packet(self, port, information):
packet = build_reliable_info_packet(
sequence_number=self.send_variable,
ack_number=self.receive_variable,
poll=True, port=port, information=information)
self.command_socket.send(packet)
self.stats['info_packets_sent'] += 1
self.last_packet_sent_time = time.time()
def send(self, port, information):
if self.closed:
raise exceptions.TransportNotReady(
'I/O operation on closed transport')
if not self.opened.is_set():
raise exceptions.TransportNotReady(
'Attempted to send a packet while the reliable transport '
'is not open')
if len(information) > self.mtu:
raise ValueError('Packet length (%d) exceeds transport MTU (%d)' % (
len(information), self.mtu))
self.send_queue.put((port, information))
self.pump_send_queue()
def process_ack(self, ack_number):
with self.transmit_lock:
if not self.waiting_for_ack:
# Could be in the timer recovery condition (waiting for
# a response to an RR Poll command). This is a bit
# hacky and should probably be changed to use an
# explicit state machine when this transport is
# extended to support Go-Back-N ARQ.
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = None
self.retransmit_count = 0
if (ack_number - 1) % self.MODULUS == self.send_variable:
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = None
self.retransmit_count = 0
self.waiting_for_ack = False
self.send_variable = (self.send_variable + 1) % self.MODULUS
if self.last_packet_sent_time:
self.stats['round_trip_time'].update(
(time.time() - self.last_packet_sent_time) * 1000)
def pump_send_queue(self):
with self.transmit_lock:
if not self.waiting_for_ack:
try:
port, information = self.send_queue.get_nowait()
self.last_sent_packet = (port, information)
self.waiting_for_ack = True
self._send_info_packet(port, information)
self.start_retransmit_timer()
except queue.Empty:
pass
def start_retransmit_timer(self):
if self.retransmit_timer:
self.retransmit_timer.cancel()
self.retransmit_timer = threading.Timer(
self.retransmit_timeout,
self.retransmit_timeout_expired)
self.retransmit_timer.daemon = True
self.retransmit_timer.start()
def retransmit_timeout_expired(self):
with self.transmit_lock:
self.retransmit_count += 1
if self.retransmit_count < self.max_retransmits:
self.stats['retransmits'] += 1
if self.last_sent_packet:
self._send_info_packet(*self.last_sent_packet)
else:
# No info packet to retransmit; must be an RR command
# that needs to be retransmitted.
self.send_supervisory_command(kind='RR', poll=True)
self.start_retransmit_timer()
else:
self.logger.warning('Reached maximum number of retransmit '
'attempts')
self.ncp.restart()
def send_supervisory_command(self, kind, poll=False):
with self.transmit_lock:
command = build_reliable_supervisory_packet(
kind=kind, poll=poll, ack_number=self.receive_variable)
self.command_socket.send(command)
def send_supervisory_response(self, kind, final=False):
with self.transmit_lock:
response = build_reliable_supervisory_packet(
kind=kind, final=final, ack_number=self.receive_variable)
self.response_socket.send(response)
def command_packet_received(self, packet):
if not self.ncp.is_Opened():
self.logger.warning('Received command packet before transport '
'is open. Discarding.')
return
# Information packets have the LSBit of the first byte cleared.
is_info = (bytearray(packet[0])[0] & 0b1) == 0
try:
if is_info:
fields = ReliableInfoPacket.parse(packet)
else:
fields = ReliableSupervisoryPacket.parse(packet)
except (construct.ConstructError, ValueError):
self.logger.exception('Received malformed command packet')
self.ncp.restart()
return
self.opened.set()
if is_info:
if fields.sequence_number == self.receive_variable:
self.receive_variable = (
self.receive_variable + 1) % self.MODULUS
self.stats['info_packets_received'] += 1
if len(fields.information) + 6 == fields.length:
if fields.port in self.sockets:
self.sockets[fields.port].on_receive(
fields.information)
else:
self.logger.warning(
'Received packet on closed port %04X',
fields.port)
else:
self.logger.error(
'Received truncated or corrupt info packet '
'(expected %d data bytes, got %d)',
fields.length-6, len(fields.information))
else:
self.stats['out_of_order_packets'] += 1
self.send_supervisory_response(kind='RR', final=fields.poll)
else:
if fields.kind not in ('RR', 'REJ'):
self.logger.error('Received a %s command packet, which is not '
'yet supported by this implementation',
fields.kind)
# Pretend it is an RR packet
self.process_ack(fields.ack_number)
if fields.poll:
self.send_supervisory_response(kind='RR', final=True)
self.pump_send_queue()
def response_packet_received(self, packet):
if not self.ncp.is_Opened():
self.logger.error(
'Received response packet before transport is open. '
'Discarding.')
return
# Information packets cannot be responses; we only need to
# handle receiving Supervisory packets.
try:
fields = ReliableSupervisoryPacket.parse(packet)
except (construct.ConstructError, ValueError):
self.logger.exception('Received malformed response packet')
self.ncp.restart()
return
self.opened.set()
self.process_ack(fields.ack_number)
self.pump_send_queue()
if fields.kind not in ('RR', 'REJ'):
self.logger.error('Received a %s response packet, which is not '
'yet supported by this implementation.',
fields.kind)

View file

@ -0,0 +1,60 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
# To use a consistent encoding
from codecs import open
from os import path
import sys
here = path.abspath(path.dirname(__file__))
# Get the long description from the README file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
long_description = f.read()
requires = [
'cobs',
'construct>=2.5.3,<2.8',
'pyserial>=2.7,<3',
'transitions>=0.4.0',
]
test_requires = []
if sys.version_info < (3, 3, 0):
test_requires.append('mock>=2.0.0')
if sys.version_info < (3, 4, 0):
requires.append('enum34')
setup(
name='pebble.pulse2',
version='0.0.7',
description='Python tools for connecting to PULSEv2 links',
long_description=long_description,
url='https://github.com/pebble/pulse2',
author='Pebble Technology Corporation',
author_email='cory@pebble.com',
packages=find_packages(exclude=['contrib', 'docs', 'tests']),
namespace_packages = ['pebble'],
install_requires=requires,
extras_require={
'test': test_requires,
},
test_suite = 'tests',
)

View file

@ -0,0 +1,14 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,60 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class FakeTimer(object):
TIMERS = []
def __init__(self, interval, function):
self.interval = interval
self.function = function
self.started = False
self.expired = False
self.cancelled = False
type(self).TIMERS.append(self)
def __repr__(self):
state_flags = ''.join([
'S' if self.started else 'N',
'X' if self.expired else '.',
'C' if self.cancelled else '.'])
return '<FakeTimer({}, {}) {} at {:#x}>'.format(
self.interval, self.function, state_flags, id(self))
def start(self):
if self.started:
raise RuntimeError("threads can only be started once")
self.started = True
def cancel(self):
self.cancelled = True
def expire(self):
'''Simulate the timeout expiring.'''
assert self.started, 'timer not yet started'
assert not self.expired, 'timer can only expire once'
self.expired = True
self.function()
@property
def is_active(self):
return self.started and not self.expired and not self.cancelled
@classmethod
def clear_timer_list(cls):
cls.TIMERS = []
@classmethod
def get_active_timers(cls):
return [t for t in cls.TIMERS if t.is_active]

View file

@ -0,0 +1,156 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import unittest
from pebble.pulse2 import framing
class TestEncodeFrame(unittest.TestCase):
def test_empty_frame(self):
# CRC-32 of nothing is 0
# COBS encoding of b'\0\0\0\0' is b'\x01\x01\x01\x01\x01' (5 bytes)
self.assertEqual(framing.encode_frame(b''),
b'\x55\x01\x01\x01\x01\x01\x55')
def test_simple_data(self):
self.assertEqual(framing.encode_frame(b'abcdefg'),
b'\x55\x0cabcdefg\xa6\x6a\x2a\x31\x55')
def test_flag_in_datagram(self):
# ASCII 'U' is 0x55 hex
self.assertEqual(framing.encode_frame(b'QUACK'),
b'\x55\x0aQ\0ACK\xdf\x8d\x80\x74\x55')
def test_flag_in_fcs(self):
# crc32(b'R') -> 0x5767df55
# Since there is an \x55 byte in the FCS, it must be substituted,
# just like when that byte value is present in the datagram itself.
self.assertEqual(framing.encode_frame(b'R'),
b'\x55\x06R\0\xdf\x67\x57\x55')
class TestFrameSplitter(unittest.TestCase):
def setUp(self):
self.splitter = framing.FrameSplitter()
def test_basic_functionality(self):
self.splitter.write(b'\x55abcdefg\x55foobar\x55asdf\x55')
self.assertEqual(list(self.splitter),
[b'abcdefg', b'foobar', b'asdf'])
def test_wait_for_sync(self):
self.splitter.write(b'garbage data\x55frame 1\x55')
self.assertEqual(list(self.splitter), [b'frame 1'])
def test_doubled_flags(self):
self.splitter.write(b'\x55abcd\x55\x55efgh\x55')
self.assertEqual(list(self.splitter), [b'abcd', b'efgh'])
def test_multiple_writes(self):
self.splitter.write(b'\x55ab')
self.assertEqual(list(self.splitter), [])
self.splitter.write(b'cd\x55')
self.assertEqual(list(self.splitter), [b'abcd'])
def test_lots_of_writes(self):
for char in b'\x55abcd\x55ef':
self.splitter.write(bytearray([char]))
self.assertEqual(list(self.splitter), [b'abcd'])
def test_iteration_pops_frames(self):
self.splitter.write(b'\x55frame 1\x55frame 2\x55frame 3\x55')
self.assertEqual(next(iter(self.splitter)), b'frame 1')
self.assertEqual(list(self.splitter), [b'frame 2', b'frame 3'])
def test_stopiteration_latches(self):
# The iterator protocol requires that once an iterator raises
# StopIteration, it must continue to do so for all subsequent calls
# to its next() method.
self.splitter.write(b'\x55frame 1\x55')
iterator = iter(self.splitter)
self.assertEqual(next(iterator), b'frame 1')
with self.assertRaises(StopIteration):
next(iterator)
next(iterator)
self.splitter.write(b'\x55frame 2\x55')
with self.assertRaises(StopIteration):
next(iterator)
self.assertEqual(list(self.splitter), [b'frame 2'])
def test_max_frame_length(self):
splitter = framing.FrameSplitter(max_frame_length=6)
splitter.write(
b'\x5512345\x55123456\x551234567\x551234\x5512345678\x55')
self.assertEqual(list(splitter), [b'12345', b'123456', b'1234'])
def test_dynamic_max_length_1(self):
self.splitter.write(b'\x5512345')
self.splitter.max_frame_length = 6
self.splitter.write(b'6\x551234567\x551234\x55')
self.assertEqual(list(self.splitter), [b'123456', b'1234'])
def test_dynamic_max_length_2(self):
self.splitter.write(b'\x551234567')
self.splitter.max_frame_length = 6
self.splitter.write(b'89\x55123456\x55')
self.assertEqual(list(self.splitter), [b'123456'])
class TestDecodeTransparency(unittest.TestCase):
def test_easy_decode(self):
self.assertEqual(framing.decode_transparency(b'\x06abcde'), b'abcde')
def test_escaped_flag(self):
self.assertEqual(framing.decode_transparency(b'\x06Q\0ACK'), b'QUACK')
def test_flag_byte_in_frame(self):
with self.assertRaises(framing.DecodeError):
framing.decode_transparency(b'\x06ab\x55de')
def test_truncated_cobs_block(self):
with self.assertRaises(framing.DecodeError):
framing.decode_transparency(b'\x0aabc')
class TestStripFCS(unittest.TestCase):
def test_frame_too_short(self):
with self.assertRaises(framing.CorruptFrame):
framing.strip_fcs(b'abcd')
def test_good_fcs(self):
self.assertEqual(framing.strip_fcs(b'abcd\x11\xcd\x82\xed'), b'abcd')
def test_frame_corrupted(self):
with self.assertRaises(framing.CorruptFrame):
framing.strip_fcs(b'abce\x11\xcd\x82\xed')
def test_fcs_corrupted(self):
with self.assertRaises(framing.CorruptFrame):
framing.strip_fcs(b'abcd\x13\xcd\x82\xed')
class TestDecodeFrame(unittest.TestCase):
def test_it_works(self):
# Not much to test; decode_frame is just chained decode_transparency
# with strip_fcs, and both of those have already been tested separately.
self.assertEqual(framing.decode_frame(b'\x0aQ\0ACK\xdf\x8d\x80t'),
b'QUACK')

View file

@ -0,0 +1,261 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import threading
import time
import unittest
try:
from unittest import mock
except ImportError:
import mock
try:
import queue
except ImportError:
import Queue as queue
from pebble.pulse2 import exceptions, framing, link, ppp
class FakeIOStream(object):
def __init__(self):
self.read_queue = queue.Queue()
self.write_queue = queue.Queue()
self.closed = False
def read(self, length):
if self.closed:
raise IOError('I/O operation on closed FakeIOStream')
try:
return self.read_queue.get(timeout=0.001)
except queue.Empty:
return b''
def write(self, data):
if self.closed:
raise IOError('I/O operation on closed FakeIOStream')
self.write_queue.put(data)
def close(self):
self.closed = True
def pop_all_written_data(self):
data = []
try:
while True:
data.append(self.write_queue.get_nowait())
except queue.Empty:
pass
return data
class TestInterface(unittest.TestCase):
def setUp(self):
self.iostream = FakeIOStream()
self.uut = link.Interface(self.iostream)
self.addCleanup(self.iostream.close)
# Speed up test execution by overriding the LCP timeout
self.uut.lcp.restart_timeout = 0.001
self.uut.lcp.ping = self.fake_ping
self.ping_should_succeed = True
def fake_ping(self, cb, *args, **kwargs):
cb(self.ping_should_succeed)
def test_send_packet(self):
self.uut.send_packet(0x8889, b'data')
self.assertIn(framing.encode_frame(ppp.encapsulate(0x8889, b'data')),
self.iostream.pop_all_written_data())
def test_connect_returns_socket(self):
self.assertIsNotNone(self.uut.connect(0xf0f1))
def test_send_from_socket(self):
socket = self.uut.connect(0xf0f1)
socket.send(b'data')
self.assertIn(framing.encode_frame(ppp.encapsulate(0xf0f1, b'data')),
self.iostream.pop_all_written_data())
def test_interface_closing_closes_sockets_and_iostream(self):
socket1 = self.uut.connect(0xf0f1)
socket2 = self.uut.connect(0xf0f3)
self.uut.close()
self.assertTrue(socket1.closed)
self.assertTrue(socket2.closed)
self.assertTrue(self.iostream.closed)
def test_iostream_closing_closes_interface_and_sockets(self):
socket = self.uut.connect(0xf0f1)
self.iostream.close()
time.sleep(0.01) # Wait for receive thread to notice
self.assertTrue(self.uut.closed)
self.assertTrue(socket.closed)
def test_opening_two_sockets_on_same_protocol_is_an_error(self):
socket1 = self.uut.connect(0xf0f1)
with self.assertRaisesRegexp(ValueError, 'socket is already bound'):
socket2 = self.uut.connect(0xf0f1)
def test_closing_socket_allows_another_to_be_opened(self):
socket1 = self.uut.connect(0xf0f1)
socket1.close()
socket2 = self.uut.connect(0xf0f1)
self.assertIsNot(socket1, socket2)
def test_sending_from_closed_interface_is_an_error(self):
self.uut.close()
with self.assertRaisesRegexp(ValueError, 'closed interface'):
self.uut.send_packet(0x8889, b'data')
def test_get_link_returns_None_when_lcp_is_down(self):
self.assertIsNone(self.uut.get_link(timeout=0))
def test_get_link_from_closed_interface_is_an_error(self):
self.uut.close()
with self.assertRaisesRegexp(ValueError, 'closed interface'):
self.uut.get_link(timeout=0)
def test_get_link_when_lcp_is_up(self):
self.uut.on_link_up()
self.assertIsNotNone(self.uut.get_link(timeout=0))
def test_link_object_is_closed_when_lcp_goes_down(self):
self.uut.on_link_up()
link = self.uut.get_link(timeout=0)
self.assertFalse(link.closed)
self.uut.on_link_down()
self.assertTrue(link.closed)
def test_lcp_bouncing_doesnt_reopen_old_link_object(self):
self.uut.on_link_up()
link1 = self.uut.get_link(timeout=0)
self.uut.on_link_down()
self.uut.on_link_up()
link2 = self.uut.get_link(timeout=0)
self.assertTrue(link1.closed)
self.assertFalse(link2.closed)
def test_close_gracefully_shuts_down_lcp(self):
self.uut.lcp.receive_configure_request_acceptable(0, b'')
self.uut.lcp.receive_configure_ack()
self.uut.close()
self.assertTrue(self.uut.lcp.is_finished.is_set())
def test_ping_failure_triggers_lcp_restart(self):
self.ping_should_succeed = False
self.uut.lcp.restart = mock.Mock()
self.uut.on_link_up()
self.assertIsNone(self.uut.get_link(timeout=0))
self.uut.lcp.restart.assert_called_once_with()
class TestInterfaceSocket(unittest.TestCase):
def setUp(self):
self.interface = mock.MagicMock()
self.uut = link.InterfaceSocket(self.interface, 0xf2f1)
def test_socket_is_not_closed_when_constructed(self):
self.assertFalse(self.uut.closed)
def test_send(self):
self.uut.send(b'data')
self.interface.send_packet.assert_called_once_with(0xf2f1, b'data')
def test_close_sets_socket_as_closed(self):
self.uut.close()
self.assertTrue(self.uut.closed)
def test_close_unregisters_socket_with_interface(self):
self.uut.close()
self.interface.unregister_socket.assert_called_once_with(0xf2f1)
def test_close_calls_on_close_handler(self):
on_close = mock.Mock()
self.uut.on_close = on_close
self.uut.close()
on_close.assert_called_once_with()
def test_send_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.send(b'data')
def test_handle_packet(self):
self.uut.on_packet = mock.Mock()
self.uut.handle_packet(b'data')
self.uut.on_packet.assert_called_once_with(b'data')
def test_handle_packet_does_not_call_on_packet_handler_after_close(self):
on_packet = mock.Mock()
self.uut.on_packet = on_packet
self.uut.close()
self.uut.handle_packet(b'data')
on_packet.assert_not_called()
def test_context_manager(self):
with self.uut as uut:
self.assertIs(self.uut, uut)
self.assertFalse(self.uut.closed)
self.assertTrue(self.uut.closed)
def test_close_is_idempotent(self):
on_close = mock.Mock()
self.uut.on_close = on_close
self.uut.close()
self.uut.close()
self.assertEqual(1, self.interface.unregister_socket.call_count)
self.assertEqual(1, on_close.call_count)
class TestLink(unittest.TestCase):
def setUp(self):
transports_patcher = mock.patch.dict(
link.Link.TRANSPORTS, {'fake': mock.Mock()}, clear=True)
transports_patcher.start()
self.addCleanup(transports_patcher.stop)
self.uut = link.Link(mock.Mock(), 1500)
def test_open_socket(self):
socket = self.uut.open_socket(
transport='fake', port=0xabcd, timeout=1.0)
self.uut.transports['fake'].open_socket.assert_called_once_with(
0xabcd, 1.0)
self.assertIs(socket, self.uut.transports['fake'].open_socket())
def test_down(self):
self.uut.down()
self.assertTrue(self.uut.closed)
self.uut.transports['fake'].down.assert_called_once_with()
def test_on_close_callback_when_going_down(self):
self.uut.on_close = mock.Mock()
self.uut.down()
self.uut.on_close.assert_called_once_with()
def test_open_socket_after_down_is_an_error(self):
self.uut.down()
with self.assertRaisesRegexp(ValueError, 'closed Link'):
self.uut.open_socket('fake', 0xabcd)
def test_open_socket_with_bad_transport_name(self):
with self.assertRaisesRegexp(KeyError, "Unknown transport 'bad'"):
self.uut.open_socket('bad', 0xabcd)

View file

@ -0,0 +1,168 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
try:
from unittest import mock
except ImportError:
import mock
from pebble.pulse2 import pcmp
from .fake_timer import FakeTimer
class TestPCMP(unittest.TestCase):
def setUp(self):
self.uut = pcmp.PulseControlMessageProtocol(mock.Mock(), 1)
def test_close_unregisters_the_socket(self):
self.uut.close()
self.uut.transport.unregister_socket.assert_called_once_with(1)
def test_close_is_idempotent(self):
self.uut.close()
self.uut.close()
self.assertEqual(1, self.uut.transport.unregister_socket.call_count)
def test_send_unknown_code(self):
self.uut.send_unknown_code(42)
self.uut.transport.send.assert_called_once_with(1, b'\x82\x2a')
def test_send_echo_request(self):
self.uut.send_echo_request(b'abcdefg')
self.uut.transport.send.assert_called_once_with(1, b'\x01abcdefg')
def test_send_echo_reply(self):
self.uut.send_echo_reply(b'abcdefg')
self.uut.transport.send.assert_called_once_with(1, b'\x02abcdefg')
def test_on_receive_empty_packet(self):
self.uut.on_receive(b'')
self.uut.transport.send.assert_not_called()
def test_on_receive_message_with_unknown_code(self):
self.uut.on_receive(b'\x00')
self.uut.transport.send.assert_called_once_with(1, b'\x82\x00')
def test_on_receive_malformed_unknown_code_message_1(self):
self.uut.on_receive(b'\x82')
self.uut.transport.send.assert_not_called()
def test_on_receive_malformed_unknown_code_message_2(self):
self.uut.on_receive(b'\x82\x00\x01')
self.uut.transport.send.assert_not_called()
def test_on_receive_discard_request(self):
self.uut.on_receive(b'\x03')
self.uut.transport.send.assert_not_called()
def test_on_receive_discard_request_with_data(self):
self.uut.on_receive(b'\x03asdfasdfasdf')
self.uut.transport.send.assert_not_called()
def test_on_receive_echo_request(self):
self.uut.on_receive(b'\x01')
self.uut.transport.send.assert_called_once_with(1, b'\x02')
def test_on_receive_echo_request_with_data(self):
self.uut.on_receive(b'\x01a')
self.uut.transport.send.assert_called_once_with(1, b'\x02a')
def test_on_receive_echo_reply(self):
self.uut.on_receive(b'\x02')
self.uut.transport.send.assert_not_called()
def test_on_receive_echo_reply_with_data(self):
self.uut.on_receive(b'\x02abc')
self.uut.transport.send.assert_not_called()
def test_on_receive_port_closed_with_no_handler(self):
self.uut.on_receive(b'\x81\xab\xcd')
self.uut.transport.send.assert_not_called()
def test_on_receive_port_closed(self):
self.uut.on_port_closed = mock.Mock()
self.uut.on_receive(b'\x81\xab\xcd')
self.uut.on_port_closed.assert_called_once_with(0xabcd)
def test_on_receive_malformed_port_closed_message_1(self):
self.uut.on_port_closed = mock.Mock()
self.uut.on_receive(b'\x81\xab')
self.uut.on_port_closed.assert_not_called()
def test_on_receive_malformed_port_closed_message_2(self):
self.uut.on_port_closed = mock.Mock()
self.uut.on_receive(b'\x81\xab\xcd\xef')
self.uut.on_port_closed.assert_not_called()
class TestPing(unittest.TestCase):
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
self.uut = pcmp.PulseControlMessageProtocol(mock.Mock(), 1)
def test_successful_ping(self):
cb = mock.Mock()
self.uut.ping(cb)
self.uut.on_receive(b'\x02')
cb.assert_called_once_with(True)
self.assertFalse(FakeTimer.get_active_timers())
def test_ping_succeeds_after_retry(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=2)
FakeTimer.TIMERS[-1].expire()
self.uut.on_receive(b'\x02')
cb.assert_called_once_with(True)
self.assertFalse(FakeTimer.get_active_timers())
def test_ping_succeeds_after_multiple_retries(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=3)
timer1 = FakeTimer.TIMERS[-1]
timer1.expire()
timer2 = FakeTimer.TIMERS[-1]
self.assertIsNot(timer1, timer2)
timer2.expire()
self.uut.on_receive(b'\x02')
cb.assert_called_once_with(True)
self.assertFalse(FakeTimer.get_active_timers())
def test_failed_ping(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=1)
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
self.assertFalse(FakeTimer.get_active_timers())
def test_ping_fails_after_multiple_retries(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=3)
for _ in range(3):
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
self.assertFalse(FakeTimer.get_active_timers())
def test_socket_close_aborts_ping(self):
cb = mock.Mock()
self.uut.ping(cb, attempts=3)
self.uut.close()
cb.assert_not_called()
self.assertFalse(FakeTimer.get_active_timers())

View file

@ -0,0 +1,589 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import unittest
try:
from unittest import mock
except ImportError:
import mock
import construct
from pebble.pulse2 import ppp, exceptions
from .fake_timer import FakeTimer
from . import timer_helper
class TestPPPEncapsulation(unittest.TestCase):
def test_ppp_encapsulate(self):
self.assertEqual(ppp.encapsulate(0xc021, b'Information'),
b'\xc0\x21Information')
class TestPPPUnencapsulate(unittest.TestCase):
def test_ppp_unencapsulate(self):
protocol, information = ppp.unencapsulate(b'\xc0\x21Information')
self.assertEqual((protocol, information), (0xc021, b'Information'))
def test_unencapsulate_empty_frame(self):
with self.assertRaises(ppp.UnencapsulationError):
ppp.unencapsulate(b'')
def test_unencapsulate_too_short_frame(self):
with self.assertRaises(ppp.UnencapsulationError):
ppp.unencapsulate(b'\x21')
def test_unencapsulate_empty_information(self):
protocol, information = ppp.unencapsulate(b'\xc0\x21')
self.assertEqual((protocol, information), (0xc021, b''))
class TestConfigurationOptionsParser(unittest.TestCase):
def test_no_options(self):
options = ppp.OptionList.parse(b'')
self.assertEqual(len(options), 0)
def test_one_empty_option(self):
options = ppp.OptionList.parse(b'\xaa\x02')
self.assertEqual(len(options), 1)
self.assertEqual(options[0].type, 0xaa)
self.assertEqual(options[0].data, b'')
def test_one_option_with_length(self):
options = ppp.OptionList.parse(b'\xab\x07Data!')
self.assertEqual((0xab, b'Data!'), options[0])
def test_multiple_options_empty_first(self):
options = ppp.OptionList.parse(b'\x22\x02\x23\x03a\x21\x04ab')
self.assertEqual([(0x22, b''), (0x23, b'a'), (0x21, b'ab')], options)
def test_multiple_options_dataful_first(self):
options = ppp.OptionList.parse(b'\x31\x08option\x32\x02')
self.assertEqual([(0x31, b'option'), (0x32, b'')], options)
def test_option_with_length_too_short(self):
with self.assertRaises(ppp.ParseError):
ppp.OptionList.parse(b'\x41\x01')
def test_option_list_with_malformed_option(self):
with self.assertRaises(ppp.ParseError):
ppp.OptionList.parse(b'\x0a\x02\x0b\x01\x0c\x03a')
def test_truncated_terminal_option(self):
with self.assertRaises(ppp.ParseError):
ppp.OptionList.parse(b'\x61\x02\x62\x03a\x63\x0ccandleja')
class TestConfigurationOptionsBuilder(unittest.TestCase):
def test_no_options(self):
serialized = ppp.OptionList.build([])
self.assertEqual(b'', serialized)
def test_one_empty_option(self):
serialized = ppp.OptionList.build([ppp.Option(0xaa, b'')])
self.assertEqual(b'\xaa\x02', serialized)
def test_one_option_with_length(self):
serialized = ppp.OptionList.build([ppp.Option(0xbb, b'Data!')])
self.assertEqual(b'\xbb\x07Data!', serialized)
def test_two_options(self):
serialized = ppp.OptionList.build([
ppp.Option(0xcc, b'foo'), ppp.Option(0xdd, b'xyzzy')])
self.assertEqual(b'\xcc\x05foo\xdd\x07xyzzy', serialized)
class TestLCPEnvelopeParsing(unittest.TestCase):
def test_packet_no_padding(self):
parsed = ppp.LCPEncapsulation.parse(b'\x01\xab\x00\x0aabcdef')
self.assertEqual(parsed.code, 1)
self.assertEqual(parsed.identifier, 0xab)
self.assertEqual(parsed.data, b'abcdef')
self.assertEqual(parsed.padding, b'')
def test_padding(self):
parsed = ppp.LCPEncapsulation.parse(b'\x01\xab\x00\x0aabcdefpadding')
self.assertEqual(parsed.data, b'abcdef')
self.assertEqual(parsed.padding, b'padding')
def test_truncated_packet(self):
with self.assertRaises(ppp.ParseError):
ppp.LCPEncapsulation.parse(b'\x01\xab\x00\x0aabcde')
def test_bogus_length(self):
with self.assertRaises(ppp.ParseError):
ppp.LCPEncapsulation.parse(b'\x01\xbc\x00\x03')
def test_empty_data(self):
parsed = ppp.LCPEncapsulation.parse(b'\x03\x01\x00\x04')
self.assertEqual((3, 1, b'', b''), parsed)
class TestLCPEnvelopeBuilder(unittest.TestCase):
def test_build_empty_data(self):
serialized = ppp.LCPEncapsulation.build(1, 0xfe, b'')
self.assertEqual(b'\x01\xfe\x00\x04', serialized)
def test_build_with_data(self):
serialized = ppp.LCPEncapsulation.build(3, 0x2a, b'Hello, world!')
self.assertEqual(b'\x03\x2a\x00\x11Hello, world!', serialized)
class TestProtocolRejectParsing(unittest.TestCase):
def test_protocol_and_info(self):
self.assertEqual((0xabcd, b'asdfasdf'),
ppp.ProtocolReject.parse(b'\xab\xcdasdfasdf'))
def test_empty_info(self):
self.assertEqual((0xf00d, b''),
ppp.ProtocolReject.parse(b'\xf0\x0d'))
def test_truncated_packet(self):
with self.assertRaises(ppp.ParseError):
ppp.ProtocolReject.parse(b'\xab')
class TestMagicNumberAndDataParsing(unittest.TestCase):
def test_magic_and_data(self):
self.assertEqual(
(0xabcdef01, b'datadata'),
ppp.MagicNumberAndData.parse(b'\xab\xcd\xef\x01datadata'))
def test_magic_no_data(self):
self.assertEqual(
(0xfeedface, b''),
ppp.MagicNumberAndData.parse(b'\xfe\xed\xfa\xce'))
def test_truncated_packet(self):
with self.assertRaises(ppp.ParseError):
ppp.MagicNumberAndData.parse(b'abc')
class TestMagicNumberAndDataBuilder(unittest.TestCase):
def test_build_empty_data(self):
serialized = ppp.MagicNumberAndData.build(0x12345678, b'')
self.assertEqual(b'\x12\x34\x56\x78', serialized)
def test_build_with_data(self):
serialized = ppp.MagicNumberAndData.build(0xabcdef01, b'foobar')
self.assertEqual(b'\xab\xcd\xef\x01foobar', serialized)
def test_build_with_named_attributes(self):
serialized = ppp.MagicNumberAndData.build(magic_number=0, data=b'abc')
self.assertEqual(b'\0\0\0\0abc', serialized)
class TestControlProtocolRestartTimer(unittest.TestCase):
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
self.uut = ppp.ControlProtocol()
self.uut.timeout_retry = mock.Mock()
self.uut.timeout_giveup = mock.Mock()
self.uut.restart_count = 5
def test_timeout_event_called_if_generation_ids_match(self):
self.uut.restart_timer_expired(self.uut.restart_timer_generation_id)
self.uut.timeout_retry.assert_called_once_with()
def test_timeout_event_not_called_if_generation_ids_mismatch(self):
self.uut.restart_timer_expired(42)
self.uut.timeout_retry.assert_not_called()
self.uut.timeout_giveup.assert_not_called()
def test_timeout_event_not_called_after_stopped(self):
self.uut.start_restart_timer(1)
self.uut.stop_restart_timer()
FakeTimer.TIMERS[-1].expire()
self.uut.timeout_retry.assert_not_called()
self.uut.timeout_giveup.assert_not_called()
def test_timeout_event_not_called_from_old_timer_after_restart(self):
self.uut.start_restart_timer(1)
zombie_timer = FakeTimer.get_active_timers()[-1]
self.uut.start_restart_timer(1)
zombie_timer.expire()
self.uut.timeout_retry.assert_not_called()
self.uut.timeout_giveup.assert_not_called()
def test_timeout_event_called_only_once_after_restart(self):
self.uut.start_restart_timer(1)
self.uut.start_restart_timer(1)
for timer in FakeTimer.TIMERS:
timer.expire()
self.uut.timeout_retry.assert_called_once_with()
self.uut.timeout_giveup.assert_not_called()
class InstrumentedControlProtocol(ppp.ControlProtocol):
methods_to_mock = (
'this_layer_up this_layer_down this_layer_started '
'this_layer_finished send_packet start_restart_timer '
'stop_restart_timer').split()
attributes_to_mock = ('restart_timer',)
def __init__(self):
ppp.ControlProtocol.__init__(self)
for method in self.methods_to_mock:
setattr(self, method, mock.Mock())
for attr in self.attributes_to_mock:
setattr(self, attr, mock.NonCallableMock())
class ControlProtocolTestMixin(object):
CONTROL_CODE_ENUM = ppp.ControlCode
def _map_control_code(self, code):
try:
return int(code)
except ValueError:
return self.CONTROL_CODE_ENUM[code].value
def assert_packet_sent(self, code, identifier, body=b''):
self.fsm.send_packet.assert_called_once_with(
ppp.LCPEncapsulation.build(
self._map_control_code(code), identifier, body))
self.fsm.send_packet.reset_mock()
def incoming_packet(self, code, identifier, body=b''):
self.fsm.packet_received(
ppp.LCPEncapsulation.build(self._map_control_code(code),
identifier, body))
class TestControlProtocolFSM(ControlProtocolTestMixin, unittest.TestCase):
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.fsm = InstrumentedControlProtocol()
def test_open_down(self):
self.fsm.open()
self.fsm.this_layer_started.assert_called_once_with()
self.fsm.this_layer_up.assert_not_called()
self.fsm.this_layer_down.assert_not_called()
self.fsm.this_layer_finished.assert_not_called()
def test_closed_up(self):
self.fsm.up(mock.Mock())
self.fsm.this_layer_up.assert_not_called()
self.fsm.this_layer_down.assert_not_called()
self.fsm.this_layer_started.assert_not_called()
self.fsm.this_layer_finished.assert_not_called()
def test_trivial_handshake(self):
self.fsm.open()
self.fsm.up(mock.Mock())
self.assert_packet_sent('Configure_Request', 0)
self.incoming_packet('Configure_Ack', 0)
self.incoming_packet('Configure_Request', 17)
self.assert_packet_sent('Configure_Ack', 17)
self.assertEqual('Opened', self.fsm.state)
self.assertTrue(self.fsm.this_layer_up.called)
self.assertEqual(self.fsm.restart_count, self.fsm.max_configure)
def test_terminate_cleanly(self):
self.test_trivial_handshake()
self.fsm.close()
self.fsm.this_layer_down.assert_called_once_with()
self.assert_packet_sent('Terminate_Request', 42)
def test_remote_terminate(self):
self.test_trivial_handshake()
self.incoming_packet('Terminate_Request', 42)
self.assert_packet_sent('Terminate_Ack', 42)
self.assertTrue(self.fsm.this_layer_down.called)
self.assertTrue(self.fsm.start_restart_timer.called)
self.fsm.this_layer_finished.assert_not_called()
self.fsm.restart_timer_expired(self.fsm.restart_timer_generation_id)
self.assertTrue(self.fsm.this_layer_finished.called)
self.assertEqual('Stopped', self.fsm.state)
def test_remote_rejects_configure_request_code(self):
self.fsm.open()
self.fsm.up(mock.Mock())
received_packet = self.fsm.send_packet.call_args[0][0]
self.assert_packet_sent('Configure_Request', 0)
self.incoming_packet('Code_Reject', 3, received_packet)
self.assertEqual('Stopped', self.fsm.state)
self.assertTrue(self.fsm.this_layer_finished.called)
def test_receive_extended_code(self):
self.fsm.handle_unknown_code = mock.Mock()
self.test_trivial_handshake()
self.incoming_packet(42, 11, b'Life, the universe and everything')
self.fsm.handle_unknown_code.assert_called_once_with(
42, 11, b'Life, the universe and everything')
def test_receive_unimplemented_code(self):
self.test_trivial_handshake()
self.incoming_packet(0x55, 0)
self.assert_packet_sent('Code_Reject', 0, b'\x55\0\0\x04')
def test_code_reject_truncates_rejected_packet(self):
self.test_trivial_handshake()
self.incoming_packet(0xaa, 0x20, 'a'*1496) # 1500-byte Info
self.assert_packet_sent('Code_Reject', 0,
b'\xaa\x20\x05\xdc' + b'a'*1492)
def test_code_reject_identifier_changes(self):
self.test_trivial_handshake()
self.incoming_packet(0xaa, 0)
self.assert_packet_sent('Code_Reject', 0, b'\xaa\0\0\x04')
self.incoming_packet(0xaa, 0)
self.assert_packet_sent('Code_Reject', 1, b'\xaa\0\0\x04')
# Local events: up, down, open, close
# Option negotiation: reject, nak
# Exceptional situations: catastrophic code-reject
# Restart negotiation after opening
# Remote Terminate-Req, -Ack at various points in the lifecycle
# Negotiation infinite loop
# Local side gives up on negotiation
# Corrupt packets received
class TestLCPReceiveEchoRequest(ControlProtocolTestMixin, unittest.TestCase):
CONTROL_CODE_ENUM = ppp.LCPCode
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.fsm = ppp.LinkControlProtocol(mock.Mock())
self.fsm.send_packet = mock.Mock()
self.fsm.state = 'Opened'
def send_echo_request(self, identifier=0, data=b'\0\0\0\0'):
result = self.fsm.handle_unknown_code(
ppp.LCPCode.Echo_Request.value, identifier, data)
self.assertIsNot(result, NotImplemented)
def test_echo_request_is_dropped_when_not_in_opened_state(self):
self.fsm.state = 'Ack-Sent'
self.send_echo_request()
self.fsm.send_packet.assert_not_called()
def test_echo_request_elicits_reply(self):
self.send_echo_request()
self.assert_packet_sent('Echo_Reply', 0, b'\0\0\0\0')
def test_echo_request_with_data_is_echoed_in_reply(self):
self.send_echo_request(5, b'\0\0\0\0datadata')
self.assert_packet_sent('Echo_Reply', 5, b'\0\0\0\0datadata')
def test_echo_request_missing_magic_number_field_is_dropped(self):
self.send_echo_request(data=b'')
self.fsm.send_packet.assert_not_called()
def test_echo_request_with_nonzero_magic_number_is_dropped(self):
self.send_echo_request(data=b'\0\0\0\x01')
self.fsm.send_packet.assert_not_called()
class TestLCPPing(ControlProtocolTestMixin, unittest.TestCase):
CONTROL_CODE_ENUM = ppp.LCPCode
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
self.fsm = ppp.LinkControlProtocol(mock.Mock())
self.fsm.send_packet = mock.Mock()
self.fsm.state = 'Opened'
def respond_to_ping(self):
[echo_request_packet], _ = self.fsm.send_packet.call_args
self.assertEqual(b'\x09'[0], echo_request_packet[0])
echo_response_packet = b'\x0a' + echo_request_packet[1:]
self.fsm.packet_received(echo_response_packet)
def test_ping_when_lcp_is_not_opened_is_an_error(self):
cb = mock.Mock()
self.fsm.state = 'Ack-Rcvd'
with self.assertRaises(ppp.LinkStateError):
self.fsm.ping(cb)
cb.assert_not_called()
def test_zero_attempts_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), attempts=0)
def test_negative_attempts_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), attempts=-1)
def test_zero_timeout_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), timeout=0)
def test_negative_timeout_is_an_error(self):
with self.assertRaises(ValueError):
self.fsm.ping(mock.Mock(), timeout=-0.1)
def test_straightforward_ping(self):
cb = mock.Mock()
self.fsm.ping(cb)
cb.assert_not_called()
self.assertEqual(1, self.fsm.send_packet.call_count)
self.respond_to_ping()
cb.assert_called_once_with(True)
def test_one_timeout_before_responding(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=2)
FakeTimer.TIMERS[-1].expire()
cb.assert_not_called()
self.assertEqual(2, self.fsm.send_packet.call_count)
self.respond_to_ping()
cb.assert_called_once_with(True)
def test_one_attempt_with_no_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
FakeTimer.TIMERS[-1].expire()
self.assertEqual(1, self.fsm.send_packet.call_count)
cb.assert_called_once_with(False)
def test_multiple_attempts_with_no_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=2)
timer_one = FakeTimer.TIMERS[-1]
timer_one.expire()
timer_two = FakeTimer.TIMERS[-1]
self.assertIsNot(timer_one, timer_two)
timer_two.expire()
self.assertEqual(2, self.fsm.send_packet.call_count)
cb.assert_called_once_with(False)
def test_late_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
FakeTimer.TIMERS[-1].expire()
self.respond_to_ping()
cb.assert_called_once_with(False)
def test_this_layer_down_during_ping(self):
cb = mock.Mock()
self.fsm.ping(cb)
self.fsm.this_layer_down()
FakeTimer.TIMERS[-1].expire()
cb.assert_not_called()
def test_echo_reply_with_wrong_identifier(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
[echo_request_packet], _ = self.fsm.send_packet.call_args
echo_response_packet = bytearray(echo_request_packet)
echo_response_packet[0] = 0x0a
echo_response_packet[1] += 1
self.fsm.packet_received(bytes(echo_response_packet))
cb.assert_not_called()
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
def test_echo_reply_with_wrong_data(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
[echo_request_packet], _ = self.fsm.send_packet.call_args
# Generate a syntactically valid Echo-Reply with the right
# identifier but completely different data.
identifier = bytearray(echo_request_packet)[1]
echo_response_packet = bytes(
b'\x0a' + bytearray([identifier]) +
b'\0\x26\0\0\0\0bad reply bad reply bad reply.')
self.fsm.packet_received(bytes(echo_response_packet))
cb.assert_not_called()
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
def test_successive_pings_use_different_identifiers(self):
self.fsm.ping(mock.Mock(), attempts=1)
[echo_request_packet_1], _ = self.fsm.send_packet.call_args
identifier_1 = bytearray(echo_request_packet_1)[1]
self.respond_to_ping()
self.fsm.ping(mock.Mock(), attempts=1)
[echo_request_packet_2], _ = self.fsm.send_packet.call_args
identifier_2 = bytearray(echo_request_packet_2)[1]
self.assertNotEqual(identifier_1, identifier_2)
def test_unsolicited_echo_reply_doesnt_break_anything(self):
self.fsm.packet_received(b'\x0a\0\0\x08\0\0\0\0')
def test_malformed_echo_reply(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
# Only three bytes of Magic-Number
self.fsm.packet_received(b'\x0a\0\0\x07\0\0\0')
cb.assert_not_called()
# Trying to start a second ping while the first ping is still happening
def test_starting_a_ping_while_another_is_active_is_an_error(self):
cb = mock.Mock()
self.fsm.ping(cb, attempts=1)
cb2 = mock.Mock()
with self.assertRaises(exceptions.AlreadyInProgressError):
self.fsm.ping(cb2, attempts=1)
FakeTimer.TIMERS[-1].expire()
cb.assert_called_once_with(False)
cb2.assert_not_called()
# General tests:
# - Length too short for a valid packet
# - Packet truncated (length field > packet len)
# - Packet with padding
# OptionList codes:
# 1 Configure-Request
# 2 Configure-Ack
# 3 Configure-Nak
# 4 Configure-Reject
# Raw data codes:
# 5 Terminate-Request
# 6 Terminate-Ack
# 7 Code-Reject
# 8 Protocol-Reject
# - Empty Rejected-Information field
# - Rejected-Protocol field too short
# Magic number + data codes:
# 10 Echo-Reply
# 11 Discard-Request
# 12 Identification (RFC 1570)

View file

@ -0,0 +1,538 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import threading
import unittest
try:
from unittest import mock
except ImportError:
import mock
import construct
from pebble.pulse2 import exceptions, pcmp, transports
from .fake_timer import FakeTimer
from . import timer_helper
# Save a reference to the real threading.Timer for tests which need to
# use timers even while threading.Timer is patched with FakeTimer.
RealThreadingTimer = threading.Timer
class CommonTransportBeforeOpenedTestCases(object):
def test_send_raises_exception(self):
with self.assertRaises(exceptions.TransportNotReady):
self.uut.send(0xdead, b'not gonna get through')
def test_open_socket_returns_None_when_ncp_fails_to_open(self):
self.assertIsNone(self.uut.open_socket(0xbeef, timeout=0))
class CommonTransportTestCases(object):
def test_send_raises_exception_after_transport_is_closed(self):
self.uut.down()
with self.assertRaises(exceptions.TransportNotReady):
self.uut.send(0xaaaa, b'asdf')
def test_socket_is_closed_when_transport_is_closed(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.down()
self.assertTrue(socket.closed)
with self.assertRaises(exceptions.SocketClosed):
socket.send(b'foo')
def test_opening_two_sockets_on_same_port_is_an_error(self):
socket1 = self.uut.open_socket(0xabcd, timeout=0)
with self.assertRaises(KeyError):
socket2 = self.uut.open_socket(0xabcd, timeout=0)
def test_closing_a_socket_allows_another_to_be_opened(self):
socket1 = self.uut.open_socket(0xabcd, timeout=0)
socket1.close()
socket2 = self.uut.open_socket(0xabcd, timeout=0)
def test_opening_socket_fails_after_transport_down(self):
self.uut.this_layer_down()
self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0))
def test_opening_socket_succeeds_after_transport_bounces(self):
self.uut.this_layer_down()
self.uut.this_layer_up()
self.uut.open_socket(0xabcd, timeout=0)
class TestBestEffortTransportBeforeOpened(CommonTransportBeforeOpenedTestCases,
unittest.TestCase):
def setUp(self):
control_protocol_patcher = mock.patch(
'pebble.pulse2.transports.TransportControlProtocol')
control_protocol_patcher.start()
self.addCleanup(control_protocol_patcher.stop)
self.uut = transports.BestEffortApplicationTransport(
interface=mock.MagicMock(), link_mtu=1500)
self.uut.ncp.is_Opened.return_value = False
def test_open_socket_waits_for_ncp_to_open(self):
self.uut.ncp.is_Opened.return_value = True
def on_ping(cb, *args):
self.uut.packet_received(transports.BestEffortPacket.build(
construct.Container(port=0x0001, length=5,
information=b'\x02', padding=b'')))
cb(True)
with mock.patch.object(pcmp.PulseControlMessageProtocol, 'ping') \
as mock_ping:
mock_ping.side_effect = on_ping
open_thread = RealThreadingTimer(0.01, self.uut.this_layer_up)
open_thread.daemon = True
open_thread.start()
self.assertIsNotNone(self.uut.open_socket(0xbeef, timeout=0.5))
open_thread.join()
class TestBestEffortTransport(CommonTransportTestCases, unittest.TestCase):
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.uut = transports.BestEffortApplicationTransport(
interface=mock.MagicMock(), link_mtu=1500)
self.uut.ncp.receive_configure_request_acceptable(0, [])
self.uut.ncp.receive_configure_ack()
self.uut.packet_received(transports.BestEffortPacket.build(
construct.Container(port=0x0001, length=5,
information=b'\x02', padding=b'')))
def test_send(self):
self.uut.send(0xabcd, b'information')
self.uut.link_socket.send.assert_called_with(
transports.BestEffortPacket.build(construct.Container(
port=0xabcd, length=15, information=b'information',
padding=b'')))
def test_send_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
socket.send(b'info')
self.uut.link_socket.send.assert_called_with(
transports.BestEffortPacket.build(construct.Container(
port=0xabcd, length=8, information=b'info', padding=b'')))
def test_receive_from_socket_with_empty_queue(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0xabcd, length=8, information=b'info', padding=b'')))
self.assertEqual(b'info', socket.receive(block=False))
def test_receive_on_unopened_port_doesnt_reach_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0xface, length=8, information=b'info', padding=b'')))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_malformed_packet(self):
self.uut.packet_received(b'garbage')
def test_send_equal_to_mtu(self):
self.uut.send(0xaaaa, b'a'*1496)
def test_send_greater_than_mtu(self):
with self.assertRaisesRegexp(ValueError, 'Packet length'):
self.uut.send(0xaaaa, b'a'*1497)
def test_transport_down_closes_link_socket_and_ncp(self):
self.uut.down()
self.uut.link_socket.close.assert_called_with()
self.assertIsNone(self.uut.ncp.socket)
def test_pcmp_port_closed_message_closes_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.assertFalse(socket.closed)
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0x0001, length=7, information=b'\x81\xab\xcd',
padding=b'')))
self.assertTrue(socket.closed)
def test_pcmp_port_closed_message_without_socket(self):
self.uut.packet_received(
transports.BestEffortPacket.build(construct.Container(
port=0x0001, length=7, information=b'\x81\xaa\xaa',
padding=b'')))
class TestReliableTransportPacketBuilders(unittest.TestCase):
def test_build_info_packet(self):
self.assertEqual(
b'\x1e\x3f\xbe\xef\x00\x14Data goes here',
transports.build_reliable_info_packet(
sequence_number=15, ack_number=31, poll=True,
port=0xbeef, information=b'Data goes here'))
def test_build_receive_ready_packet(self):
self.assertEqual(
b'\x01\x18',
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=12))
def test_build_receive_ready_poll_packet(self):
self.assertEqual(
b'\x01\x19',
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=12, poll=True))
def test_build_receive_ready_final_packet(self):
self.assertEqual(
b'\x01\x19',
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=12, final=True))
def test_build_receive_not_ready_packet(self):
self.assertEqual(
b'\x05\x18',
transports.build_reliable_supervisory_packet(
kind='RNR', ack_number=12))
def test_build_reject_packet(self):
self.assertEqual(
b'\x09\x18',
transports.build_reliable_supervisory_packet(
kind='REJ', ack_number=12))
class TestReliableTransportBeforeOpened(CommonTransportBeforeOpenedTestCases,
unittest.TestCase):
def setUp(self):
self.addCleanup(timer_helper.cancel_all_timers)
self.uut = transports.ReliableTransport(
interface=mock.MagicMock(), link_mtu=1500)
def test_open_socket_waits_for_ncp_to_open(self):
self.uut.ncp.is_Opened = mock.Mock()
self.uut.ncp.is_Opened.return_value = True
self.uut.command_socket.send = lambda packet: (
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, final=True)))
open_thread = RealThreadingTimer(0.01, self.uut.this_layer_up)
open_thread.daemon = True
open_thread.start()
self.assertIsNotNone(self.uut.open_socket(0xbeef, timeout=0.5))
open_thread.join()
class TestReliableTransportConnectionEstablishment(unittest.TestCase):
expected_rr_packet = transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, poll=True)
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
control_protocol_patcher = mock.patch(
'pebble.pulse2.transports.TransportControlProtocol')
control_protocol_patcher.start()
self.addCleanup(control_protocol_patcher.stop)
self.uut = transports.ReliableTransport(
interface=mock.MagicMock(), link_mtu=1500)
assert isinstance(self.uut.ncp, mock.MagicMock)
self.uut.ncp.is_Opened.return_value = True
self.uut.this_layer_up()
def send_rr_response(self):
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, final=True))
def test_rr_packet_is_sent_after_this_layer_up_event(self):
self.uut.command_socket.send.assert_called_once_with(
self.expected_rr_packet)
def test_rr_command_is_retransmitted_until_response_is_received(self):
for _ in range(3):
FakeTimer.TIMERS[-1].expire()
self.send_rr_response()
self.assertFalse(FakeTimer.get_active_timers())
self.assertEqual(self.uut.command_socket.send.call_args_list,
[mock.call(self.expected_rr_packet)]*4)
self.assertIsNotNone(self.uut.open_socket(0xabcd, timeout=0))
def test_transport_negotiation_restarts_if_no_responses(self):
for _ in range(self.uut.max_retransmits):
FakeTimer.TIMERS[-1].expire()
self.assertFalse(FakeTimer.get_active_timers())
self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0))
self.uut.ncp.restart.assert_called_once_with()
class TestReliableTransport(CommonTransportTestCases,
unittest.TestCase):
def setUp(self):
FakeTimer.clear_timer_list()
timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
timer_patcher.start()
self.addCleanup(timer_patcher.stop)
control_protocol_patcher = mock.patch(
'pebble.pulse2.transports.TransportControlProtocol')
control_protocol_patcher.start()
self.addCleanup(control_protocol_patcher.stop)
self.uut = transports.ReliableTransport(
interface=mock.MagicMock(), link_mtu=1500)
assert isinstance(self.uut.ncp, mock.MagicMock)
self.uut.ncp.is_Opened.return_value = True
self.uut.this_layer_up()
self.uut.command_socket.send.reset_mock()
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=0, final=True))
def test_send_with_immediate_ack(self):
self.uut.send(0xbeef, b'Just some packet data')
self.uut.command_socket.send.assert_called_once_with(
transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True,
port=0xbeef, information=b'Just some packet data'))
self.assertEqual(1, len(FakeTimer.get_active_timers()))
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS))
def test_send_with_one_timeout_before_ack(self):
self.uut.send(0xabcd, b'this will be sent twice')
active_timers = FakeTimer.get_active_timers()
self.assertEqual(1, len(active_timers))
active_timers[0].expire()
self.assertEqual(1, len(FakeTimer.get_active_timers()))
self.uut.command_socket.send.assert_has_calls(
[mock.call(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0,
poll=True, port=0xabcd,
information=b'this will be sent twice'))]*2)
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS))
def test_send_with_no_response(self):
self.uut.send(0xd00d, b'blarg')
for _ in xrange(self.uut.max_retransmits):
FakeTimer.get_active_timers()[-1].expire()
self.uut.ncp.restart.assert_called_once_with()
def test_receive_info_packet(self):
socket = self.uut.open_socket(0xcafe, timeout=0)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0xcafe,
information=b'info'))
self.assertEqual(b'info', socket.receive(block=False))
self.uut.response_socket.send.assert_called_once_with(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
def test_receive_duplicate_packet(self):
socket = self.uut.open_socket(0xba5e, timeout=0)
packet = transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0xba5e,
information=b'all your base are belong to us')
self.uut.command_packet_received(packet)
self.assertEqual(b'all your base are belong to us',
socket.receive(block=False))
self.uut.response_socket.reset_mock()
self.uut.command_packet_received(packet)
self.uut.response_socket.send.assert_called_once_with(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=1, final=True))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_queueing_multiple_packets_to_send(self):
packets = [(0xfeed, b'Some data'),
(0x6789, b'More data'),
(0xfeed, b'Third packet')]
for protocol, information in packets:
self.uut.send(protocol, information)
for seq, (port, information) in enumerate(packets):
self.uut.command_socket.send.assert_called_once_with(
transports.build_reliable_info_packet(
sequence_number=seq, ack_number=0, poll=True,
port=port, information=information))
self.uut.command_socket.send.reset_mock()
self.uut.response_packet_received(
transports.build_reliable_supervisory_packet(
kind='RR', ack_number=seq+1, final=True))
def test_send_equal_to_mtu(self):
self.uut.send(0xaaaa, b'a'*1494)
def test_send_greater_than_mtu(self):
with self.assertRaisesRegexp(ValueError, 'Packet length'):
self.uut.send(0xaaaa, b'a'*1496)
def test_send_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
socket.send(b'info')
self.uut.command_socket.send.assert_called_with(
transports.build_reliable_info_packet(
sequence_number=0, ack_number=0,
poll=True, port=0xabcd, information=b'info'))
def test_receive_from_socket_with_empty_queue(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_from_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0xabcd,
information=b'info info info'))
self.assertEqual(b'info info info', socket.receive(block=False))
def test_receive_on_unopened_port_doesnt_reach_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0x3333,
information=b'info'))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
socket.receive(block=False)
def test_receive_malformed_command_packet(self):
self.uut.command_packet_received(b'garbage')
self.uut.ncp.restart.assert_called_once_with()
def test_receive_malformed_response_packet(self):
self.uut.response_packet_received(b'garbage')
self.uut.ncp.restart.assert_called_once_with()
def test_transport_down_closes_link_sockets_and_ncp(self):
self.uut.down()
self.uut.command_socket.close.assert_called_with()
self.uut.response_socket.close.assert_called_with()
self.uut.ncp.down.assert_called_with()
def test_pcmp_port_closed_message_closes_socket(self):
socket = self.uut.open_socket(0xabcd, timeout=0)
self.assertFalse(socket.closed)
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0x0001,
information=b'\x81\xab\xcd'))
self.assertTrue(socket.closed)
def test_pcmp_port_closed_message_without_socket(self):
self.uut.command_packet_received(transports.build_reliable_info_packet(
sequence_number=0, ack_number=0, poll=True, port=0x0001,
information=b'\x81\xaa\xaa'))
class TestSocket(unittest.TestCase):
def setUp(self):
self.uut = transports.Socket(mock.Mock(), 1234)
def test_empty_receive_queue(self):
with self.assertRaises(exceptions.ReceiveQueueEmpty):
self.uut.receive(block=False)
def test_empty_receive_queue_blocking(self):
with self.assertRaises(exceptions.ReceiveQueueEmpty):
self.uut.receive(timeout=0.001)
def test_receive(self):
self.uut.on_receive(b'data')
self.assertEqual(b'data', self.uut.receive(block=False))
with self.assertRaises(exceptions.ReceiveQueueEmpty):
self.uut.receive(block=False)
def test_receive_twice(self):
self.uut.on_receive(b'one')
self.uut.on_receive(b'two')
self.assertEqual(b'one', self.uut.receive(block=False))
self.assertEqual(b'two', self.uut.receive(block=False))
def test_receive_interleaved(self):
self.uut.on_receive(b'one')
self.assertEqual(b'one', self.uut.receive(block=False))
self.uut.on_receive(b'two')
self.assertEqual(b'two', self.uut.receive(block=False))
def test_send(self):
self.uut.send(b'data')
self.uut.transport.send.assert_called_once_with(1234, b'data')
def test_close(self):
self.uut.close()
self.uut.transport.unregister_socket.assert_called_once_with(1234)
def test_send_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.send(b'data')
def test_receive_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.receive(block=False)
def test_blocking_receive_after_close_is_an_error(self):
self.uut.close()
with self.assertRaises(exceptions.SocketClosed):
self.uut.receive(timeout=0.001)
def test_close_during_blocking_receive_aborts_the_receive(self):
thread_started = threading.Event()
result = [None]
def test_thread():
thread_started.set()
try:
self.uut.receive(timeout=0.3)
except Exception as e:
result[0] = e
thread = threading.Thread(target=test_thread)
thread.daemon = True
thread.start()
assert thread_started.wait(timeout=0.5)
self.uut.close()
thread.join()
self.assertIsInstance(result[0], exceptions.SocketClosed)
def test_close_is_idempotent(self):
self.uut.close()
self.uut.close()
self.assertEqual(1, self.uut.transport.unregister_socket.call_count)

View file

@ -0,0 +1,25 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
def cancel_all_timers():
'''Cancel all running timer threads in the process.
'''
for thread in threading.enumerate():
try:
thread.cancel()
except AttributeError:
pass