# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import import threading import unittest try: from unittest import mock except ImportError: import mock import construct from pebble.pulse2 import exceptions, pcmp, transports from .fake_timer import FakeTimer from . import timer_helper # Save a reference to the real threading.Timer for tests which need to # use timers even while threading.Timer is patched with FakeTimer. RealThreadingTimer = threading.Timer class CommonTransportBeforeOpenedTestCases(object): def test_send_raises_exception(self): with self.assertRaises(exceptions.TransportNotReady): self.uut.send(0xdead, b'not gonna get through') def test_open_socket_returns_None_when_ncp_fails_to_open(self): self.assertIsNone(self.uut.open_socket(0xbeef, timeout=0)) class CommonTransportTestCases(object): def test_send_raises_exception_after_transport_is_closed(self): self.uut.down() with self.assertRaises(exceptions.TransportNotReady): self.uut.send(0xaaaa, b'asdf') def test_socket_is_closed_when_transport_is_closed(self): socket = self.uut.open_socket(0xabcd, timeout=0) self.uut.down() self.assertTrue(socket.closed) with self.assertRaises(exceptions.SocketClosed): socket.send(b'foo') def test_opening_two_sockets_on_same_port_is_an_error(self): socket1 = self.uut.open_socket(0xabcd, timeout=0) with self.assertRaises(KeyError): socket2 = self.uut.open_socket(0xabcd, timeout=0) def test_closing_a_socket_allows_another_to_be_opened(self): socket1 = self.uut.open_socket(0xabcd, timeout=0) socket1.close() socket2 = self.uut.open_socket(0xabcd, timeout=0) def test_opening_socket_fails_after_transport_down(self): self.uut.this_layer_down() self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0)) def test_opening_socket_succeeds_after_transport_bounces(self): self.uut.this_layer_down() self.uut.this_layer_up() self.uut.open_socket(0xabcd, timeout=0) class TestBestEffortTransportBeforeOpened(CommonTransportBeforeOpenedTestCases, unittest.TestCase): def setUp(self): control_protocol_patcher = mock.patch( 'pebble.pulse2.transports.TransportControlProtocol') control_protocol_patcher.start() self.addCleanup(control_protocol_patcher.stop) self.uut = transports.BestEffortApplicationTransport( interface=mock.MagicMock(), link_mtu=1500) self.uut.ncp.is_Opened.return_value = False def test_open_socket_waits_for_ncp_to_open(self): self.uut.ncp.is_Opened.return_value = True def on_ping(cb, *args): self.uut.packet_received(transports.BestEffortPacket.build( construct.Container(port=0x0001, length=5, information=b'\x02', padding=b''))) cb(True) with mock.patch.object(pcmp.PulseControlMessageProtocol, 'ping') \ as mock_ping: mock_ping.side_effect = on_ping open_thread = RealThreadingTimer(0.01, self.uut.this_layer_up) open_thread.daemon = True open_thread.start() self.assertIsNotNone(self.uut.open_socket(0xbeef, timeout=0.5)) open_thread.join() class TestBestEffortTransport(CommonTransportTestCases, unittest.TestCase): def setUp(self): self.addCleanup(timer_helper.cancel_all_timers) self.uut = transports.BestEffortApplicationTransport( interface=mock.MagicMock(), link_mtu=1500) self.uut.ncp.receive_configure_request_acceptable(0, []) self.uut.ncp.receive_configure_ack() self.uut.packet_received(transports.BestEffortPacket.build( construct.Container(port=0x0001, length=5, information=b'\x02', padding=b''))) def test_send(self): self.uut.send(0xabcd, b'information') self.uut.link_socket.send.assert_called_with( transports.BestEffortPacket.build(construct.Container( port=0xabcd, length=15, information=b'information', padding=b''))) def test_send_from_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) socket.send(b'info') self.uut.link_socket.send.assert_called_with( transports.BestEffortPacket.build(construct.Container( port=0xabcd, length=8, information=b'info', padding=b''))) def test_receive_from_socket_with_empty_queue(self): socket = self.uut.open_socket(0xabcd, timeout=0) with self.assertRaises(exceptions.ReceiveQueueEmpty): socket.receive(block=False) def test_receive_from_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) self.uut.packet_received( transports.BestEffortPacket.build(construct.Container( port=0xabcd, length=8, information=b'info', padding=b''))) self.assertEqual(b'info', socket.receive(block=False)) def test_receive_on_unopened_port_doesnt_reach_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) self.uut.packet_received( transports.BestEffortPacket.build(construct.Container( port=0xface, length=8, information=b'info', padding=b''))) with self.assertRaises(exceptions.ReceiveQueueEmpty): socket.receive(block=False) def test_receive_malformed_packet(self): self.uut.packet_received(b'garbage') def test_send_equal_to_mtu(self): self.uut.send(0xaaaa, b'a'*1496) def test_send_greater_than_mtu(self): with self.assertRaisesRegexp(ValueError, 'Packet length'): self.uut.send(0xaaaa, b'a'*1497) def test_transport_down_closes_link_socket_and_ncp(self): self.uut.down() self.uut.link_socket.close.assert_called_with() self.assertIsNone(self.uut.ncp.socket) def test_pcmp_port_closed_message_closes_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) self.assertFalse(socket.closed) self.uut.packet_received( transports.BestEffortPacket.build(construct.Container( port=0x0001, length=7, information=b'\x81\xab\xcd', padding=b''))) self.assertTrue(socket.closed) def test_pcmp_port_closed_message_without_socket(self): self.uut.packet_received( transports.BestEffortPacket.build(construct.Container( port=0x0001, length=7, information=b'\x81\xaa\xaa', padding=b''))) class TestReliableTransportPacketBuilders(unittest.TestCase): def test_build_info_packet(self): self.assertEqual( b'\x1e\x3f\xbe\xef\x00\x14Data goes here', transports.build_reliable_info_packet( sequence_number=15, ack_number=31, poll=True, port=0xbeef, information=b'Data goes here')) def test_build_receive_ready_packet(self): self.assertEqual( b'\x01\x18', transports.build_reliable_supervisory_packet( kind='RR', ack_number=12)) def test_build_receive_ready_poll_packet(self): self.assertEqual( b'\x01\x19', transports.build_reliable_supervisory_packet( kind='RR', ack_number=12, poll=True)) def test_build_receive_ready_final_packet(self): self.assertEqual( b'\x01\x19', transports.build_reliable_supervisory_packet( kind='RR', ack_number=12, final=True)) def test_build_receive_not_ready_packet(self): self.assertEqual( b'\x05\x18', transports.build_reliable_supervisory_packet( kind='RNR', ack_number=12)) def test_build_reject_packet(self): self.assertEqual( b'\x09\x18', transports.build_reliable_supervisory_packet( kind='REJ', ack_number=12)) class TestReliableTransportBeforeOpened(CommonTransportBeforeOpenedTestCases, unittest.TestCase): def setUp(self): self.addCleanup(timer_helper.cancel_all_timers) self.uut = transports.ReliableTransport( interface=mock.MagicMock(), link_mtu=1500) def test_open_socket_waits_for_ncp_to_open(self): self.uut.ncp.is_Opened = mock.Mock() self.uut.ncp.is_Opened.return_value = True self.uut.command_socket.send = lambda packet: ( self.uut.response_packet_received( transports.build_reliable_supervisory_packet( kind='RR', ack_number=0, final=True))) open_thread = RealThreadingTimer(0.01, self.uut.this_layer_up) open_thread.daemon = True open_thread.start() self.assertIsNotNone(self.uut.open_socket(0xbeef, timeout=0.5)) open_thread.join() class TestReliableTransportConnectionEstablishment(unittest.TestCase): expected_rr_packet = transports.build_reliable_supervisory_packet( kind='RR', ack_number=0, poll=True) def setUp(self): FakeTimer.clear_timer_list() timer_patcher = mock.patch('threading.Timer', new=FakeTimer) timer_patcher.start() self.addCleanup(timer_patcher.stop) control_protocol_patcher = mock.patch( 'pebble.pulse2.transports.TransportControlProtocol') control_protocol_patcher.start() self.addCleanup(control_protocol_patcher.stop) self.uut = transports.ReliableTransport( interface=mock.MagicMock(), link_mtu=1500) assert isinstance(self.uut.ncp, mock.MagicMock) self.uut.ncp.is_Opened.return_value = True self.uut.this_layer_up() def send_rr_response(self): self.uut.response_packet_received( transports.build_reliable_supervisory_packet( kind='RR', ack_number=0, final=True)) def test_rr_packet_is_sent_after_this_layer_up_event(self): self.uut.command_socket.send.assert_called_once_with( self.expected_rr_packet) def test_rr_command_is_retransmitted_until_response_is_received(self): for _ in range(3): FakeTimer.TIMERS[-1].expire() self.send_rr_response() self.assertFalse(FakeTimer.get_active_timers()) self.assertEqual(self.uut.command_socket.send.call_args_list, [mock.call(self.expected_rr_packet)]*4) self.assertIsNotNone(self.uut.open_socket(0xabcd, timeout=0)) def test_transport_negotiation_restarts_if_no_responses(self): for _ in range(self.uut.max_retransmits): FakeTimer.TIMERS[-1].expire() self.assertFalse(FakeTimer.get_active_timers()) self.assertIsNone(self.uut.open_socket(0xabcd, timeout=0)) self.uut.ncp.restart.assert_called_once_with() class TestReliableTransport(CommonTransportTestCases, unittest.TestCase): def setUp(self): FakeTimer.clear_timer_list() timer_patcher = mock.patch('threading.Timer', new=FakeTimer) timer_patcher.start() self.addCleanup(timer_patcher.stop) control_protocol_patcher = mock.patch( 'pebble.pulse2.transports.TransportControlProtocol') control_protocol_patcher.start() self.addCleanup(control_protocol_patcher.stop) self.uut = transports.ReliableTransport( interface=mock.MagicMock(), link_mtu=1500) assert isinstance(self.uut.ncp, mock.MagicMock) self.uut.ncp.is_Opened.return_value = True self.uut.this_layer_up() self.uut.command_socket.send.reset_mock() self.uut.response_packet_received( transports.build_reliable_supervisory_packet( kind='RR', ack_number=0, final=True)) def test_send_with_immediate_ack(self): self.uut.send(0xbeef, b'Just some packet data') self.uut.command_socket.send.assert_called_once_with( transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0xbeef, information=b'Just some packet data')) self.assertEqual(1, len(FakeTimer.get_active_timers())) self.uut.response_packet_received( transports.build_reliable_supervisory_packet( kind='RR', ack_number=1, final=True)) self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS)) def test_send_with_one_timeout_before_ack(self): self.uut.send(0xabcd, b'this will be sent twice') active_timers = FakeTimer.get_active_timers() self.assertEqual(1, len(active_timers)) active_timers[0].expire() self.assertEqual(1, len(FakeTimer.get_active_timers())) self.uut.command_socket.send.assert_has_calls( [mock.call(transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0xabcd, information=b'this will be sent twice'))]*2) self.uut.response_packet_received( transports.build_reliable_supervisory_packet( kind='RR', ack_number=1, final=True)) self.assertTrue(all(t.cancelled for t in FakeTimer.TIMERS)) def test_send_with_no_response(self): self.uut.send(0xd00d, b'blarg') for _ in xrange(self.uut.max_retransmits): FakeTimer.get_active_timers()[-1].expire() self.uut.ncp.restart.assert_called_once_with() def test_receive_info_packet(self): socket = self.uut.open_socket(0xcafe, timeout=0) self.uut.command_packet_received(transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0xcafe, information=b'info')) self.assertEqual(b'info', socket.receive(block=False)) self.uut.response_socket.send.assert_called_once_with( transports.build_reliable_supervisory_packet( kind='RR', ack_number=1, final=True)) def test_receive_duplicate_packet(self): socket = self.uut.open_socket(0xba5e, timeout=0) packet = transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0xba5e, information=b'all your base are belong to us') self.uut.command_packet_received(packet) self.assertEqual(b'all your base are belong to us', socket.receive(block=False)) self.uut.response_socket.reset_mock() self.uut.command_packet_received(packet) self.uut.response_socket.send.assert_called_once_with( transports.build_reliable_supervisory_packet( kind='RR', ack_number=1, final=True)) with self.assertRaises(exceptions.ReceiveQueueEmpty): socket.receive(block=False) def test_queueing_multiple_packets_to_send(self): packets = [(0xfeed, b'Some data'), (0x6789, b'More data'), (0xfeed, b'Third packet')] for protocol, information in packets: self.uut.send(protocol, information) for seq, (port, information) in enumerate(packets): self.uut.command_socket.send.assert_called_once_with( transports.build_reliable_info_packet( sequence_number=seq, ack_number=0, poll=True, port=port, information=information)) self.uut.command_socket.send.reset_mock() self.uut.response_packet_received( transports.build_reliable_supervisory_packet( kind='RR', ack_number=seq+1, final=True)) def test_send_equal_to_mtu(self): self.uut.send(0xaaaa, b'a'*1494) def test_send_greater_than_mtu(self): with self.assertRaisesRegexp(ValueError, 'Packet length'): self.uut.send(0xaaaa, b'a'*1496) def test_send_from_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) socket.send(b'info') self.uut.command_socket.send.assert_called_with( transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0xabcd, information=b'info')) def test_receive_from_socket_with_empty_queue(self): socket = self.uut.open_socket(0xabcd, timeout=0) with self.assertRaises(exceptions.ReceiveQueueEmpty): socket.receive(block=False) def test_receive_from_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) self.uut.command_packet_received(transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0xabcd, information=b'info info info')) self.assertEqual(b'info info info', socket.receive(block=False)) def test_receive_on_unopened_port_doesnt_reach_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) self.uut.command_packet_received(transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0x3333, information=b'info')) with self.assertRaises(exceptions.ReceiveQueueEmpty): socket.receive(block=False) def test_receive_malformed_command_packet(self): self.uut.command_packet_received(b'garbage') self.uut.ncp.restart.assert_called_once_with() def test_receive_malformed_response_packet(self): self.uut.response_packet_received(b'garbage') self.uut.ncp.restart.assert_called_once_with() def test_transport_down_closes_link_sockets_and_ncp(self): self.uut.down() self.uut.command_socket.close.assert_called_with() self.uut.response_socket.close.assert_called_with() self.uut.ncp.down.assert_called_with() def test_pcmp_port_closed_message_closes_socket(self): socket = self.uut.open_socket(0xabcd, timeout=0) self.assertFalse(socket.closed) self.uut.command_packet_received(transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0x0001, information=b'\x81\xab\xcd')) self.assertTrue(socket.closed) def test_pcmp_port_closed_message_without_socket(self): self.uut.command_packet_received(transports.build_reliable_info_packet( sequence_number=0, ack_number=0, poll=True, port=0x0001, information=b'\x81\xaa\xaa')) class TestSocket(unittest.TestCase): def setUp(self): self.uut = transports.Socket(mock.Mock(), 1234) def test_empty_receive_queue(self): with self.assertRaises(exceptions.ReceiveQueueEmpty): self.uut.receive(block=False) def test_empty_receive_queue_blocking(self): with self.assertRaises(exceptions.ReceiveQueueEmpty): self.uut.receive(timeout=0.001) def test_receive(self): self.uut.on_receive(b'data') self.assertEqual(b'data', self.uut.receive(block=False)) with self.assertRaises(exceptions.ReceiveQueueEmpty): self.uut.receive(block=False) def test_receive_twice(self): self.uut.on_receive(b'one') self.uut.on_receive(b'two') self.assertEqual(b'one', self.uut.receive(block=False)) self.assertEqual(b'two', self.uut.receive(block=False)) def test_receive_interleaved(self): self.uut.on_receive(b'one') self.assertEqual(b'one', self.uut.receive(block=False)) self.uut.on_receive(b'two') self.assertEqual(b'two', self.uut.receive(block=False)) def test_send(self): self.uut.send(b'data') self.uut.transport.send.assert_called_once_with(1234, b'data') def test_close(self): self.uut.close() self.uut.transport.unregister_socket.assert_called_once_with(1234) def test_send_after_close_is_an_error(self): self.uut.close() with self.assertRaises(exceptions.SocketClosed): self.uut.send(b'data') def test_receive_after_close_is_an_error(self): self.uut.close() with self.assertRaises(exceptions.SocketClosed): self.uut.receive(block=False) def test_blocking_receive_after_close_is_an_error(self): self.uut.close() with self.assertRaises(exceptions.SocketClosed): self.uut.receive(timeout=0.001) def test_close_during_blocking_receive_aborts_the_receive(self): thread_started = threading.Event() result = [None] def test_thread(): thread_started.set() try: self.uut.receive(timeout=0.3) except Exception as e: result[0] = e thread = threading.Thread(target=test_thread) thread.daemon = True thread.start() assert thread_started.wait(timeout=0.5) self.uut.close() thread.join() self.assertIsInstance(result[0], exceptions.SocketClosed) def test_close_is_idempotent(self): self.uut.close() self.uut.close() self.assertEqual(1, self.uut.transport.unregister_socket.call_count)