# pebble/python_libs/pulse2/tests/test_transports.py
# 2025-01-27 11:38:16 -08:00
#
# 538 lines
# 22 KiB
# Python

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import threading
import unittest
try:
from unittest import mock
except ImportError:
import mock
import construct
from pebble.pulse2 import exceptions, pcmp, transports
from .fake_timer import FakeTimer
from . import timer_helper
# Tests that run while threading.Timer is patched out with FakeTimer can still
# schedule genuine background timers through this alias. It is bound once, at
# import time, before any test has started patching.
from threading import Timer as RealThreadingTimer
class CommonTransportBeforeOpenedTestCases(object):
    """Mixin with checks shared by every transport before its NCP is Opened.

    Concrete subclasses combine this with unittest.TestCase and assign the
    transport under test to ``self.uut`` in setUp().
    """

    def test_send_raises_exception(self):
        # Nothing has been negotiated yet, so sending must be refused.
        self.assertRaises(exceptions.TransportNotReady,
                          self.uut.send, 0xdead, b'not gonna get through')

    def test_open_socket_returns_None_when_ncp_fails_to_open(self):
        # open_socket is expected to report failure by returning None
        # rather than by raising.
        opened = self.uut.open_socket(0xbeef, timeout=0)
        self.assertIs(opened, None)
class CommonTransportTestCases(object):
    """Mixin with checks shared by every transport once it is opened.

    Concrete subclasses combine this with unittest.TestCase and, in setUp(),
    assign a fully opened transport to ``self.uut``.
    """

    def test_send_raises_exception_after_transport_is_closed(self):
        self.uut.down()
        with self.assertRaises(exceptions.TransportNotReady):
            self.uut.send(0xaaaa, b'asdf')

    def test_socket_is_closed_when_transport_is_closed(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        self.uut.down()
        self.assertTrue(socket.closed)
        with self.assertRaises(exceptions.SocketClosed):
            socket.send(b'foo')

    def test_opening_two_sockets_on_same_port_is_an_error(self):
        # Hold on to the first socket for the duration of the test and make
        # sure it really was opened, so the KeyError below is meaningful.
        socket1 = self.uut.open_socket(0xabcd, timeout=0)
        self.assertIsNotNone(socket1)
        with self.assertRaises(KeyError):
            self.uut.open_socket(0xabcd, timeout=0)

    def test_closing_a_socket_allows_another_to_be_opened(self):
        socket1 = self.uut.open_socket(0xabcd, timeout=0)
        socket1.close()
        # open_socket reports failure by returning None instead of raising,
        # so the reopened socket has to be checked explicitly; otherwise the
        # test would pass even if reopening failed.
        socket2 = self.uut.open_socket(0xabcd, timeout=0)
        self.assertIsNotNone(socket2)
        self.assertFalse(socket2.closed)

    def test_opening_socket_fails_after_transport_down(self):
        self.uut.this_layer_down()
        self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0))

    def test_opening_socket_succeeds_after_transport_bounces(self):
        self.uut.this_layer_down()
        self.uut.this_layer_up()
        # NOTE(review): this only checks that open_socket does not raise; it
        # does not assert that a socket is returned. Presumably each
        # transport must complete its own post-up handshake before
        # open_socket(timeout=0) can succeed — TODO confirm and tighten.
        self.uut.open_socket(0xabcd, timeout=0)
class TestBestEffortTransportBeforeOpened(CommonTransportBeforeOpenedTestCases,
                                          unittest.TestCase):
    """BestEffortApplicationTransport behaviour before its NCP is Opened."""

    def setUp(self):
        # Swap in a mock control protocol so each test decides when the NCP
        # reports itself as Opened.
        ncp_patcher = mock.patch(
            'pebble.pulse2.transports.TransportControlProtocol')
        ncp_patcher.start()
        self.addCleanup(ncp_patcher.stop)
        self.uut = transports.BestEffortApplicationTransport(
            interface=mock.MagicMock(), link_mtu=1500)
        self.uut.ncp.is_Opened.return_value = False

    def test_open_socket_waits_for_ncp_to_open(self):
        self.uut.ncp.is_Opened.return_value = True

        def fake_ping(callback, *_):
            # Stand-in for PCMP ping: inject an inbound packet on port
            # 0x0001, then report success to the caller's callback.
            reply = construct.Container(port=0x0001, length=5,
                                        information=b'\x02', padding=b'')
            self.uut.packet_received(transports.BestEffortPacket.build(reply))
            callback(True)

        with mock.patch.object(pcmp.PulseControlMessageProtocol, 'ping',
                               side_effect=fake_ping):
            # Bring the layer up from a genuine background timer shortly
            # after open_socket() has started waiting.
            bring_up = RealThreadingTimer(0.01, self.uut.this_layer_up)
            bring_up.daemon = True
            bring_up.start()
            opened = self.uut.open_socket(0xbeef, timeout=0.5)
            self.assertIsNotNone(opened)
            bring_up.join()
class TestBestEffortTransport(CommonTransportTestCases, unittest.TestCase):
    """BestEffortApplicationTransport behaviour once its NCP is Opened."""

    def setUp(self):
        self.addCleanup(timer_helper.cancel_all_timers)
        self.uut = transports.BestEffortApplicationTransport(
            interface=mock.MagicMock(), link_mtu=1500)
        # Drive the real NCP through configure-request and configure-ack.
        self.uut.ncp.receive_configure_request_acceptable(0, [])
        self.uut.ncp.receive_configure_ack()
        # Deliver the same port-0x0001 packet (payload 0x02) that the fake
        # PCMP ping injects in TestBestEffortTransportBeforeOpened, presumably
        # answering the transport's post-open ping — TODO confirm.
        self.uut.packet_received(transports.BestEffortPacket.build(
            construct.Container(port=0x0001, length=5,
                                information=b'\x02', padding=b'')))

    def test_send(self):
        # The length field covers the 4 port+length bytes plus the
        # information (11 bytes -> 15).
        self.uut.send(0xabcd, b'information')
        self.uut.link_socket.send.assert_called_with(
            transports.BestEffortPacket.build(construct.Container(
                port=0xabcd, length=15, information=b'information',
                padding=b'')))

    def test_send_from_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        socket.send(b'info')
        self.uut.link_socket.send.assert_called_with(
            transports.BestEffortPacket.build(construct.Container(
                port=0xabcd, length=8, information=b'info', padding=b'')))

    def test_receive_from_socket_with_empty_queue(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            socket.receive(block=False)

    def test_receive_from_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        self.uut.packet_received(
            transports.BestEffortPacket.build(construct.Container(
                port=0xabcd, length=8, information=b'info', padding=b'')))
        self.assertEqual(b'info', socket.receive(block=False))

    def test_receive_on_unopened_port_doesnt_reach_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        self.uut.packet_received(
            transports.BestEffortPacket.build(construct.Container(
                port=0xface, length=8, information=b'info', padding=b'')))
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            socket.receive(block=False)

    def test_receive_malformed_packet(self):
        # Passes as long as an unparseable packet does not raise.
        self.uut.packet_received(b'garbage')

    def test_send_equal_to_mtu(self):
        # 1496 information bytes + 4 header bytes exactly fill the 1500-byte
        # link MTU configured in setUp().
        self.uut.send(0xaaaa, b'a'*1496)

    def test_send_greater_than_mtu(self):
        # assertRaisesRegexp was removed in Python 3.12 and assertRaisesRegex
        # does not exist on Python 2, so check the message explicitly.
        with self.assertRaises(ValueError) as ctx:
            self.uut.send(0xaaaa, b'a'*1497)
        self.assertIn('Packet length', str(ctx.exception))

    def test_transport_down_closes_link_socket_and_ncp(self):
        self.uut.down()
        self.uut.link_socket.close.assert_called_with()
        self.assertIsNone(self.uut.ncp.socket)

    def test_pcmp_port_closed_message_closes_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        self.assertFalse(socket.closed)
        # PCMP port-closed message (0x81) naming port 0xabcd.
        self.uut.packet_received(
            transports.BestEffortPacket.build(construct.Container(
                port=0x0001, length=7, information=b'\x81\xab\xcd',
                padding=b'')))
        self.assertTrue(socket.closed)

    def test_pcmp_port_closed_message_without_socket(self):
        # A port-closed message for a port with no open socket must not
        # raise.
        self.uut.packet_received(
            transports.BestEffortPacket.build(construct.Container(
                port=0x0001, length=7, information=b'\x81\xaa\xaa',
                padding=b'')))
class TestReliableTransportPacketBuilders(unittest.TestCase):
    """Byte-exact checks of the reliable-transport frame builders."""

    def _check_supervisory(self, expected, **fields):
        # Build a supervisory frame from keyword fields and compare bytes.
        built = transports.build_reliable_supervisory_packet(**fields)
        self.assertEqual(expected, built)

    def test_build_info_packet(self):
        built = transports.build_reliable_info_packet(
            sequence_number=15, ack_number=31, poll=True,
            port=0xbeef, information=b'Data goes here')
        self.assertEqual(b'\x1e\x3f\xbe\xef\x00\x14Data goes here', built)

    def test_build_receive_ready_packet(self):
        self._check_supervisory(b'\x01\x18', kind='RR', ack_number=12)

    def test_build_receive_ready_poll_packet(self):
        self._check_supervisory(b'\x01\x19',
                                kind='RR', ack_number=12, poll=True)

    def test_build_receive_ready_final_packet(self):
        # The final flag encodes to the same bit as the poll flag.
        self._check_supervisory(b'\x01\x19',
                                kind='RR', ack_number=12, final=True)

    def test_build_receive_not_ready_packet(self):
        self._check_supervisory(b'\x05\x18', kind='RNR', ack_number=12)

    def test_build_reject_packet(self):
        self._check_supervisory(b'\x09\x18', kind='REJ', ack_number=12)
class TestReliableTransportBeforeOpened(CommonTransportBeforeOpenedTestCases,
                                        unittest.TestCase):
    """ReliableTransport behaviour before its NCP is Opened."""

    def setUp(self):
        self.addCleanup(timer_helper.cancel_all_timers)
        self.uut = transports.ReliableTransport(
            interface=mock.MagicMock(), link_mtu=1500)

    def test_open_socket_waits_for_ncp_to_open(self):
        self.uut.ncp.is_Opened = mock.Mock(return_value=True)

        def answer_command(packet):
            # Acknowledge every outgoing command at once with an RR final
            # response, as the peer would answer the RR poll.
            return self.uut.response_packet_received(
                transports.build_reliable_supervisory_packet(
                    kind='RR', ack_number=0, final=True))

        self.uut.command_socket.send = answer_command
        # Bring the layer up from a genuine background timer shortly after
        # open_socket() has started waiting.
        bring_up = RealThreadingTimer(0.01, self.uut.this_layer_up)
        bring_up.daemon = True
        bring_up.start()
        opened = self.uut.open_socket(0xbeef, timeout=0.5)
        self.assertIsNotNone(opened)
        bring_up.join()
class TestReliableTransportConnectionEstablishment(unittest.TestCase):
    """Link establishment of ReliableTransport after this_layer_up()."""

    # The RR poll command the transport is expected to send on layer-up.
    expected_rr_packet = transports.build_reliable_supervisory_packet(
        kind='RR', ack_number=0, poll=True)

    def setUp(self):
        # FakeTimer replaces threading.Timer so retransmission timeouts are
        # fired explicitly by the tests instead of by wall-clock time.
        FakeTimer.clear_timer_list()
        timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
        timer_patcher.start()
        self.addCleanup(timer_patcher.stop)
        control_protocol_patcher = mock.patch(
            'pebble.pulse2.transports.TransportControlProtocol')
        control_protocol_patcher.start()
        self.addCleanup(control_protocol_patcher.stop)
        self.uut = transports.ReliableTransport(
            interface=mock.MagicMock(), link_mtu=1500)
        # Use a unittest assertion: a bare assert is stripped under -O.
        self.assertIsInstance(self.uut.ncp, mock.MagicMock)
        self.uut.ncp.is_Opened.return_value = True
        self.uut.this_layer_up()

    def send_rr_response(self):
        """Deliver the peer's RR final response acknowledging the poll."""
        self.uut.response_packet_received(
            transports.build_reliable_supervisory_packet(
                kind='RR', ack_number=0, final=True))

    def test_rr_packet_is_sent_after_this_layer_up_event(self):
        self.uut.command_socket.send.assert_called_once_with(
            self.expected_rr_packet)

    def test_rr_command_is_retransmitted_until_response_is_received(self):
        # Let the newest retransmission timer fire three times before the
        # peer finally answers.
        for _ in range(3):
            FakeTimer.TIMERS[-1].expire()
        self.send_rr_response()
        self.assertFalse(FakeTimer.get_active_timers())
        # The initial poll plus three retransmissions.
        self.assertEqual(self.uut.command_socket.send.call_args_list,
                         [mock.call(self.expected_rr_packet)]*4)
        self.assertIsNotNone(self.uut.open_socket(0xabcd, timeout=0))

    def test_transport_negotiation_restarts_if_no_responses(self):
        for _ in range(self.uut.max_retransmits):
            FakeTimer.TIMERS[-1].expire()
        self.assertFalse(FakeTimer.get_active_timers())
        self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0))
        self.uut.ncp.restart.assert_called_once_with()
class TestReliableTransport(CommonTransportTestCases,
                            unittest.TestCase):
    """ReliableTransport behaviour once the link is fully established."""

    def setUp(self):
        # FakeTimer replaces threading.Timer so retransmission timeouts are
        # fired explicitly by the tests instead of by wall-clock time.
        FakeTimer.clear_timer_list()
        timer_patcher = mock.patch('threading.Timer', new=FakeTimer)
        timer_patcher.start()
        self.addCleanup(timer_patcher.stop)
        control_protocol_patcher = mock.patch(
            'pebble.pulse2.transports.TransportControlProtocol')
        control_protocol_patcher.start()
        self.addCleanup(control_protocol_patcher.stop)
        self.uut = transports.ReliableTransport(
            interface=mock.MagicMock(), link_mtu=1500)
        # Use a unittest assertion: a bare assert is stripped under -O.
        self.assertIsInstance(self.uut.ncp, mock.MagicMock)
        self.uut.ncp.is_Opened.return_value = True
        self.uut.this_layer_up()
        # Forget the RR poll sent by this_layer_up() so each test only sees
        # its own traffic, then acknowledge that poll to finish link setup.
        self.uut.command_socket.send.reset_mock()
        self.uut.response_packet_received(
            transports.build_reliable_supervisory_packet(
                kind='RR', ack_number=0, final=True))

    def test_send_with_immediate_ack(self):
        self.uut.send(0xbeef, b'Just some packet data')
        self.uut.command_socket.send.assert_called_once_with(
            transports.build_reliable_info_packet(
                sequence_number=0, ack_number=0, poll=True,
                port=0xbeef, information=b'Just some packet data'))
        # One retransmission timer is armed while awaiting the ack...
        self.assertEqual(1, len(FakeTimer.get_active_timers()))
        self.uut.response_packet_received(
            transports.build_reliable_supervisory_packet(
                kind='RR', ack_number=1, final=True))
        # ...and every timer has been cancelled once the ack arrives.
        self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS))

    def test_send_with_one_timeout_before_ack(self):
        self.uut.send(0xabcd, b'this will be sent twice')
        active_timers = FakeTimer.get_active_timers()
        self.assertEqual(1, len(active_timers))
        # Expiring the timer retransmits the frame and leaves exactly one
        # timer active.
        active_timers[0].expire()
        self.assertEqual(1, len(FakeTimer.get_active_timers()))
        self.uut.command_socket.send.assert_has_calls(
            [mock.call(transports.build_reliable_info_packet(
                sequence_number=0, ack_number=0,
                poll=True, port=0xabcd,
                information=b'this will be sent twice'))]*2)
        self.uut.response_packet_received(
            transports.build_reliable_supervisory_packet(
                kind='RR', ack_number=1, final=True))
        self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS))

    def test_send_with_no_response(self):
        self.uut.send(0xd00d, b'blarg')
        # range, not xrange: xrange does not exist on Python 3.
        for _ in range(self.uut.max_retransmits):
            FakeTimer.get_active_timers()[-1].expire()
        self.uut.ncp.restart.assert_called_once_with()

    def test_receive_info_packet(self):
        socket = self.uut.open_socket(0xcafe, timeout=0)
        self.uut.command_packet_received(transports.build_reliable_info_packet(
            sequence_number=0, ack_number=0, poll=True, port=0xcafe,
            information=b'info'))
        self.assertEqual(b'info', socket.receive(block=False))
        self.uut.response_socket.send.assert_called_once_with(
            transports.build_reliable_supervisory_packet(
                kind='RR', ack_number=1, final=True))

    def test_receive_duplicate_packet(self):
        socket = self.uut.open_socket(0xba5e, timeout=0)
        packet = transports.build_reliable_info_packet(
            sequence_number=0, ack_number=0, poll=True, port=0xba5e,
            information=b'all your base are belong to us')
        self.uut.command_packet_received(packet)
        self.assertEqual(b'all your base are belong to us',
                         socket.receive(block=False))
        self.uut.response_socket.reset_mock()
        # The duplicate is acknowledged again but not delivered twice.
        self.uut.command_packet_received(packet)
        self.uut.response_socket.send.assert_called_once_with(
            transports.build_reliable_supervisory_packet(
                kind='RR', ack_number=1, final=True))
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            socket.receive(block=False)

    def test_queueing_multiple_packets_to_send(self):
        packets = [(0xfeed, b'Some data'),
                   (0x6789, b'More data'),
                   (0xfeed, b'Third packet')]
        for port, information in packets:
            self.uut.send(port, information)
        # Frames go out one at a time: each ack releases the next queued one.
        for seq, (port, information) in enumerate(packets):
            self.uut.command_socket.send.assert_called_once_with(
                transports.build_reliable_info_packet(
                    sequence_number=seq, ack_number=0, poll=True,
                    port=port, information=information))
            self.uut.command_socket.send.reset_mock()
            self.uut.response_packet_received(
                transports.build_reliable_supervisory_packet(
                    kind='RR', ack_number=seq+1, final=True))

    def test_send_equal_to_mtu(self):
        # 1494 information bytes + 6 header bytes exactly fill the 1500-byte
        # link MTU configured in setUp().
        self.uut.send(0xaaaa, b'a'*1494)

    def test_send_greater_than_mtu(self):
        # assertRaisesRegexp was removed in Python 3.12 and assertRaisesRegex
        # does not exist on Python 2, so check the message explicitly.
        with self.assertRaises(ValueError) as ctx:
            self.uut.send(0xaaaa, b'a'*1495)
        self.assertIn('Packet length', str(ctx.exception))

    def test_send_from_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        socket.send(b'info')
        self.uut.command_socket.send.assert_called_with(
            transports.build_reliable_info_packet(
                sequence_number=0, ack_number=0,
                poll=True, port=0xabcd, information=b'info'))

    def test_receive_from_socket_with_empty_queue(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            socket.receive(block=False)

    def test_receive_from_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        self.uut.command_packet_received(transports.build_reliable_info_packet(
            sequence_number=0, ack_number=0, poll=True, port=0xabcd,
            information=b'info info info'))
        self.assertEqual(b'info info info', socket.receive(block=False))

    def test_receive_on_unopened_port_doesnt_reach_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        self.uut.command_packet_received(transports.build_reliable_info_packet(
            sequence_number=0, ack_number=0, poll=True, port=0x3333,
            information=b'info'))
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            socket.receive(block=False)

    def test_receive_malformed_command_packet(self):
        self.uut.command_packet_received(b'garbage')
        self.uut.ncp.restart.assert_called_once_with()

    def test_receive_malformed_response_packet(self):
        self.uut.response_packet_received(b'garbage')
        self.uut.ncp.restart.assert_called_once_with()

    def test_transport_down_closes_link_sockets_and_ncp(self):
        self.uut.down()
        self.uut.command_socket.close.assert_called_with()
        self.uut.response_socket.close.assert_called_with()
        self.uut.ncp.down.assert_called_with()

    def test_pcmp_port_closed_message_closes_socket(self):
        socket = self.uut.open_socket(0xabcd, timeout=0)
        self.assertFalse(socket.closed)
        # PCMP port-closed message (0x81) naming port 0xabcd.
        self.uut.command_packet_received(transports.build_reliable_info_packet(
            sequence_number=0, ack_number=0, poll=True, port=0x0001,
            information=b'\x81\xab\xcd'))
        self.assertTrue(socket.closed)

    def test_pcmp_port_closed_message_without_socket(self):
        # A port-closed message for a port with no open socket must not
        # raise.
        self.uut.command_packet_received(transports.build_reliable_info_packet(
            sequence_number=0, ack_number=0, poll=True, port=0x0001,
            information=b'\x81\xaa\xaa'))
class TestSocket(unittest.TestCase):
    """Unit tests for transports.Socket against a mock transport."""

    def setUp(self):
        self.uut = transports.Socket(mock.Mock(), 1234)

    def test_empty_receive_queue(self):
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            self.uut.receive(block=False)

    def test_empty_receive_queue_blocking(self):
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            self.uut.receive(timeout=0.001)

    def test_receive(self):
        self.uut.on_receive(b'data')
        self.assertEqual(b'data', self.uut.receive(block=False))
        with self.assertRaises(exceptions.ReceiveQueueEmpty):
            self.uut.receive(block=False)

    def test_receive_twice(self):
        self.uut.on_receive(b'one')
        self.uut.on_receive(b'two')
        self.assertEqual(b'one', self.uut.receive(block=False))
        self.assertEqual(b'two', self.uut.receive(block=False))

    def test_receive_interleaved(self):
        self.uut.on_receive(b'one')
        self.assertEqual(b'one', self.uut.receive(block=False))
        self.uut.on_receive(b'two')
        self.assertEqual(b'two', self.uut.receive(block=False))

    def test_send(self):
        self.uut.send(b'data')
        self.uut.transport.send.assert_called_once_with(1234, b'data')

    def test_close(self):
        self.uut.close()
        self.uut.transport.unregister_socket.assert_called_once_with(1234)

    def test_send_after_close_is_an_error(self):
        self.uut.close()
        with self.assertRaises(exceptions.SocketClosed):
            self.uut.send(b'data')

    def test_receive_after_close_is_an_error(self):
        self.uut.close()
        with self.assertRaises(exceptions.SocketClosed):
            self.uut.receive(block=False)

    def test_blocking_receive_after_close_is_an_error(self):
        self.uut.close()
        with self.assertRaises(exceptions.SocketClosed):
            self.uut.receive(timeout=0.001)

    def test_close_during_blocking_receive_aborts_the_receive(self):
        thread_started = threading.Event()
        result = [None]

        def test_thread():
            thread_started.set()
            try:
                self.uut.receive(timeout=0.3)
            except Exception as e:
                result[0] = e

        thread = threading.Thread(target=test_thread)
        thread.daemon = True
        thread.start()
        # Do not wrap the wait() in a bare assert: ``python -O`` strips
        # assert statements, which would also drop this synchronization.
        self.assertTrue(thread_started.wait(timeout=0.5))
        # NOTE(review): the event is set just before receive() is entered,
        # so close() can still win the race. In that case receive() raises
        # SocketClosed for an already-closed socket rather than being
        # aborted mid-wait.
        self.uut.close()
        thread.join()
        self.assertIsInstance(result[0], exceptions.SocketClosed)

    def test_close_is_idempotent(self):
        self.uut.close()
        self.uut.close()
        self.assertEqual(1, self.uut.transport.unregister_socket.call_count)