mirror of
https://github.com/google/pebble.git
synced 2025-03-20 19:01:21 +00:00
445 lines
16 KiB
Python
445 lines
16 KiB
Python
|
# Copyright 2024 Google LLC
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
import logging
|
||
|
import Queue
|
||
|
import struct
|
||
|
import sys
|
||
|
import threading
|
||
|
import time
|
||
|
import traceback
|
||
|
import uuid
|
||
|
import weakref
|
||
|
|
||
|
from cobs import cobs
|
||
|
import serial
|
||
|
|
||
|
from . import exceptions
|
||
|
import stm32_crc
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
try:
|
||
|
import pyftdi.serialext
|
||
|
except ImportError:
|
||
|
pass
|
||
|
|
||
|
DBGSERIAL_PORT_SETTINGS = dict(baudrate=230400, timeout=0.1,
|
||
|
interCharTimeout=0.01)
|
||
|
|
||
|
|
||
|
def get_dbgserial_tty():
|
||
|
# Local import so that we only depend on this package if we're attempting
|
||
|
# to autodetect the TTY. This package isn't always available (e.g., MFG),
|
||
|
# so we don't want it to be required.
|
||
|
try:
|
||
|
import pebble_tty
|
||
|
return pebble_tty.find_dbgserial_tty()
|
||
|
except ImportError:
|
||
|
raise exceptions.TTYAutodetectionUnavailable
|
||
|
|
||
|
|
||
|
def frame_splitter(istream, size=1024, timeout=1, delimiter='\0'):
|
||
|
'''Returns an iterator which yields complete frames.'''
|
||
|
partial = []
|
||
|
start_time = time.time()
|
||
|
while not istream.closed:
|
||
|
data = istream.read(size)
|
||
|
logger.debug('frame_splitter: received %r', data)
|
||
|
while True:
|
||
|
left, delim, data = data.partition(delimiter)
|
||
|
if left:
|
||
|
partial.append(left)
|
||
|
if delim:
|
||
|
frame = ''.join(partial)
|
||
|
partial = []
|
||
|
if frame:
|
||
|
yield frame
|
||
|
if not data:
|
||
|
break
|
||
|
if timeout > 0 and time.time() > start_time + timeout:
|
||
|
yield
|
||
|
|
||
|
def decode_frame(frame):
|
||
|
'''Decodes a PULSE frame.
|
||
|
|
||
|
Returns a tuple (protocol, payload) of the decoded frame.
|
||
|
Raises FrameDecodeError if the frame is not valid.
|
||
|
'''
|
||
|
try:
|
||
|
data = cobs.decode(frame)
|
||
|
except cobs.DecodeError, e:
|
||
|
raise exceptions.FrameDecodeError(e.message)
|
||
|
if len(data) < 5:
|
||
|
raise exceptions.FrameDecodeError('frame too short')
|
||
|
fcs = struct.unpack('<I', data[-4:])[0]
|
||
|
crc = stm32_crc.crc32(data[:-4])
|
||
|
if fcs != crc:
|
||
|
raise exceptions.FrameDecodeError('FCS 0x%.08x != CRC 0x%.08x' % (fcs, crc))
|
||
|
protocol = ord(data[0])
|
||
|
return (protocol, data[1:-4])
|
||
|
|
||
|
def encode_frame(protocol, payload):
|
||
|
frame = struct.pack('<B', protocol)
|
||
|
frame += payload
|
||
|
fcs = stm32_crc.crc32(frame)
|
||
|
frame += struct.pack('<I', fcs)
|
||
|
return cobs.encode(frame)
|
||
|
|
||
|
|
||
|
class Connection(object):
|
||
|
'''A socket for sending and receiving datagrams over the PULSE serial
|
||
|
protocol.
|
||
|
'''
|
||
|
|
||
|
PROTOCOL_LLC = 0x01
|
||
|
|
||
|
LLC_LINK_OPEN_REQUEST = '\x01\x03\x08\x08\x08PULSEv1\r\n'
|
||
|
LLC_LINK_CLOSE_REQUEST = '\x03'
|
||
|
LLC_ECHO_REQUEST = '\x05'
|
||
|
LLC_CHANGE_BAUD = '\x07'
|
||
|
|
||
|
LLC_LINK_OPENED = 0x02
|
||
|
LLC_LINK_CLOSED = 0x04
|
||
|
LLC_ECHO_REPLY = 0x06
|
||
|
|
||
|
EXTENSIONS = {}
|
||
|
|
||
|
# Maximum round-trip time
|
||
|
rtt = 0.4
|
||
|
|
||
|
def __init__(self, iostream, infinite_reconnect=False):
|
||
|
self.infinite_reconnect = infinite_reconnect
|
||
|
self.iostream = iostream
|
||
|
self.closed = False
|
||
|
try:
|
||
|
self.initial_port_settings = self.iostream.getSettingsDict()
|
||
|
except AttributeError:
|
||
|
self.initial_port_settings = None
|
||
|
self.port_settings_altered = False
|
||
|
# Whether the link is open for sending.
|
||
|
self._link_open = threading.Event()
|
||
|
# Whether the link has been severed.
|
||
|
self._link_closed = threading.Event()
|
||
|
self.send_lock = threading.RLock()
|
||
|
self.echoes_inflight = weakref.WeakValueDictionary()
|
||
|
self.protocol_handlers = weakref.WeakValueDictionary()
|
||
|
self.receive_thread = threading.Thread(target=self.run_receive_thread)
|
||
|
self.receive_thread.daemon = True
|
||
|
self.receive_thread.start()
|
||
|
self._open_link()
|
||
|
|
||
|
self.keepalive_thread = threading.Thread(
|
||
|
target=self.run_keepalive_thread)
|
||
|
self.keepalive_thread.daemon = True
|
||
|
self.keepalive_thread.start()
|
||
|
|
||
|
# Instantiate and bind all known extensions
|
||
|
for name, factory in self.EXTENSIONS.iteritems():
|
||
|
setattr(self, name, factory(self))
|
||
|
|
||
|
@classmethod
|
||
|
def register_extension(cls, name, factory):
|
||
|
'''Register a PULSE connection extension.
|
||
|
|
||
|
When a Connection object is instantiated, the object returned by
|
||
|
factory(connection_object) is assigned to connection_object.<name>.
|
||
|
'''
|
||
|
try:
|
||
|
getattr(cls, name)
|
||
|
except AttributeError:
|
||
|
cls.EXTENSIONS[name] = factory
|
||
|
else:
|
||
|
raise ValueError('extension name %r clashes with existing attribute'
|
||
|
% (name,))
|
||
|
|
||
|
@classmethod
|
||
|
def open_dbgserial(cls, url=None, infinite_reconnect=False):
|
||
|
if url is None:
|
||
|
url = get_dbgserial_tty()
|
||
|
if url == "qemu":
|
||
|
url = 'socket://localhost:12345'
|
||
|
ser = serial.serial_for_url(url, **DBGSERIAL_PORT_SETTINGS)
|
||
|
|
||
|
if url.startswith('socket://'):
|
||
|
# Socket class for PySerial does some pointless buffering
|
||
|
# setting a very small timeout effectively negates it
|
||
|
ser._timeout = 0.00001
|
||
|
|
||
|
return cls(ser, infinite_reconnect=infinite_reconnect)
|
||
|
|
||
|
def __enter__(self):
|
||
|
return self
|
||
|
|
||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||
|
self.close()
|
||
|
|
||
|
def __del__(self):
|
||
|
self.close()
|
||
|
|
||
|
def send(self, protocol, payload):
|
||
|
if self.closed:
|
||
|
raise exceptions.PulseError('I/O operation on closed connection')
|
||
|
frame = ''.join(('\0', encode_frame(protocol, payload), '\0'))
|
||
|
logger.debug('Connection: sending %r', frame)
|
||
|
with self.send_lock:
|
||
|
self.iostream.write(frame)
|
||
|
|
||
|
def run_receive_thread(self):
|
||
|
logger.debug('Connection: receive thread started')
|
||
|
receiver = frame_splitter(self.iostream, timeout=0)
|
||
|
while True:
|
||
|
try:
|
||
|
protocol, payload = decode_frame(next(receiver))
|
||
|
except exceptions.FrameDecodeError:
|
||
|
continue
|
||
|
except:
|
||
|
# Probably a PySerial exception complaining about reading from a
|
||
|
# closed port. Eat the exception and shut down the thread; users
|
||
|
# don't need to see the stack trace.
|
||
|
logger.debug('Connection: exception in receive thread:\n%s',
|
||
|
traceback.format_exc())
|
||
|
break
|
||
|
logger.debug('Connection:run_receive_thread: '
|
||
|
'protocol=%d payload=%r', protocol, payload)
|
||
|
if protocol == self.PROTOCOL_LLC: # LLC can't be overridden
|
||
|
self.llc_handler(payload)
|
||
|
continue
|
||
|
try:
|
||
|
handler = self.protocol_handlers[protocol]
|
||
|
except KeyError:
|
||
|
self.default_receiver(protocol, payload)
|
||
|
else:
|
||
|
handler.on_receive(payload)
|
||
|
logger.debug('Connection: receive thread exiting')
|
||
|
|
||
|
def default_receiver(self, protocol, frame):
|
||
|
logger.info('Connection:default_receiver: received frame '
|
||
|
'with protocol %d: %r', protocol, frame)
|
||
|
|
||
|
def register_protocol_handler(self, protocol, handler):
|
||
|
'''Register a handler for frames bearing the specified protocol number.
|
||
|
|
||
|
handler.on_receive(payload) is called for each frame received with the
|
||
|
protocol number.
|
||
|
|
||
|
Protocol handlers can be unregistered by calling this function with a
|
||
|
handler of None.
|
||
|
'''
|
||
|
if not handler:
|
||
|
try:
|
||
|
del self.protocol_handlers[protocol]
|
||
|
except KeyError:
|
||
|
pass
|
||
|
return
|
||
|
if protocol in self.protocol_handlers:
|
||
|
raise exceptions.ProtocolAlreadyRegistered(
|
||
|
'Protocol %d is already registered by %r' % (
|
||
|
protocol, self.protocol_handlers[protocol]))
|
||
|
if not hasattr(handler, 'on_receive'):
|
||
|
raise ValueError('%r does not have an on_receive method')
|
||
|
self.protocol_handlers[protocol] = handler
|
||
|
|
||
|
def llc_handler(self, frame):
|
||
|
opcode = ord(frame[0])
|
||
|
if opcode == self.LLC_LINK_OPENED:
|
||
|
# MTU and MRU are from the perspective of this side of the
|
||
|
# connection
|
||
|
version, mru, mtu, timeout = struct.unpack('<xBHHB', frame)
|
||
|
self.version = version
|
||
|
# The server reports the MTU inclusive of protocol number and FCS,
|
||
|
# but we only care about the maximum payload length.
|
||
|
self.mtu = mtu - 5
|
||
|
self.mru = mru
|
||
|
# Timeout is specified in deciseconds. Convert to seconds.
|
||
|
self.timeout = timeout / 10.0
|
||
|
self._link_closed.clear()
|
||
|
self._link_open.set()
|
||
|
elif opcode == self.LLC_LINK_CLOSED:
|
||
|
logger.info('PULSE connection closed.')
|
||
|
self._link_closed.set()
|
||
|
elif opcode == self.LLC_ECHO_REPLY:
|
||
|
self._on_echo_reply(frame[1:])
|
||
|
else:
|
||
|
logger.warning('Received LLC frame with unknown type %d: %r',
|
||
|
opcode, frame)
|
||
|
|
||
|
def run_keepalive_thread(self):
|
||
|
'''The keepalive thread monitors the link, reopening it if necessary.
|
||
|
'''
|
||
|
logger.debug('Connection: keepalive thread started')
|
||
|
OPEN, TEST_LIVENESS, RECONNECT = range(3)
|
||
|
state = OPEN
|
||
|
next_state = state
|
||
|
ping_attempts = 0
|
||
|
ping_wait = self.rtt
|
||
|
while True:
|
||
|
# Check whether the link is being closed from our side before
|
||
|
# trying to keep it alive.
|
||
|
if not self._link_open.is_set():
|
||
|
return
|
||
|
|
||
|
if state == OPEN:
|
||
|
time.sleep(1)
|
||
|
if self._link_closed.is_set():
|
||
|
next_state = RECONNECT
|
||
|
else:
|
||
|
next_state = TEST_LIVENESS
|
||
|
elif state == TEST_LIVENESS:
|
||
|
if ping_attempts < 3:
|
||
|
ping_attempts += 1
|
||
|
ping_wait *= 2 # Exponential backoff
|
||
|
if self.ping(ping_wait):
|
||
|
next_state = OPEN
|
||
|
else:
|
||
|
logger.info('No response to keepalive ping -- '
|
||
|
'strike %d', ping_attempts)
|
||
|
else:
|
||
|
logger.info('Connection: keepalive timed out.')
|
||
|
next_state = RECONNECT
|
||
|
elif state == RECONNECT:
|
||
|
# Lock out everyone from sending so that applications don't send
|
||
|
# to a connection that's in an indeterminate state.
|
||
|
with self.send_lock:
|
||
|
if self.port_settings_altered:
|
||
|
# Ensure that the server has timed out and reset its
|
||
|
# baud rate so we don't get into the bad situation where
|
||
|
# we try to reconnect at the default baud rate but the
|
||
|
# server is listening at a different rate, which is
|
||
|
# practically guaranteed to fail.
|
||
|
logger.info('Letting connection time out before '
|
||
|
'attempting to reconnect.')
|
||
|
time.sleep(self.timeout + self.rtt)
|
||
|
self._link_open.clear()
|
||
|
while not self._link_open.is_set():
|
||
|
try:
|
||
|
self._open_link()
|
||
|
except exceptions.PulseError as e:
|
||
|
logger.warning('Connection: reconnect failed. %s', e)
|
||
|
if not self.infinite_reconnect:
|
||
|
break
|
||
|
logger.warning('Will try again.')
|
||
|
logger.info('Backing off for a while before retrying.')
|
||
|
time.sleep(self.timeout + self.rtt)
|
||
|
else:
|
||
|
next_state = OPEN
|
||
|
else:
|
||
|
assert False, 'Invalid state %d' % state
|
||
|
|
||
|
if next_state != state:
|
||
|
if next_state == TEST_LIVENESS:
|
||
|
ping_attempts = 0
|
||
|
ping_wait = self.rtt
|
||
|
state = next_state
|
||
|
|
||
|
def _open_link(self):
|
||
|
self.closed = False
|
||
|
if self.initial_port_settings:
|
||
|
self.iostream.applySettingsDict(self.initial_port_settings)
|
||
|
for attempt in xrange(5):
|
||
|
logger.info('Opening link (attempt %d)...', attempt)
|
||
|
self.send(self.PROTOCOL_LLC, self.LLC_LINK_OPEN_REQUEST)
|
||
|
if self._link_open.wait(self.rtt):
|
||
|
logger.info('Established PULSE connection!')
|
||
|
logger.info('Version=%d MTU=%d MRU=%d Timeout=%.1f',
|
||
|
self.version, self.mtu, self.mru, self.timeout)
|
||
|
break
|
||
|
else:
|
||
|
self._link_closed.set()
|
||
|
self.closed = True
|
||
|
raise exceptions.PulseError('Could not establish connection')
|
||
|
|
||
|
def close(self):
|
||
|
self._link_open.clear()
|
||
|
if not self._link_closed.is_set():
|
||
|
for attempt in xrange(3):
|
||
|
self.send(self.PROTOCOL_LLC, self.LLC_LINK_CLOSE_REQUEST)
|
||
|
if self._link_closed.wait(self.rtt):
|
||
|
break
|
||
|
else:
|
||
|
logger.warning('Could not confirm link close.')
|
||
|
self._link_closed.set()
|
||
|
self.iostream.close()
|
||
|
self.closed = True
|
||
|
|
||
|
def ping(self, timeout=None):
|
||
|
if not timeout:
|
||
|
timeout = 2 * self.rtt
|
||
|
nonce = uuid.uuid4().bytes
|
||
|
is_received = threading.Event()
|
||
|
self.echoes_inflight[nonce] = is_received
|
||
|
self.send(self.PROTOCOL_LLC, self.LLC_ECHO_REQUEST + nonce)
|
||
|
return is_received.wait(timeout)
|
||
|
|
||
|
def _on_echo_reply(self, payload):
|
||
|
try:
|
||
|
receive_event = self.echoes_inflight[payload]
|
||
|
receive_event.set()
|
||
|
except KeyError:
|
||
|
pass
|
||
|
|
||
|
def change_baud_rate(self, new_baud):
|
||
|
# Fail fast if the IO object doesn't support changing the baud rate
|
||
|
old_baud = self.iostream.baudrate
|
||
|
self.send(self.PROTOCOL_LLC,
|
||
|
self.LLC_CHANGE_BAUD + struct.pack('<I', new_baud))
|
||
|
# Be extra sure that the message has been sent and it's safe to adjust
|
||
|
# the baud rate on the port.
|
||
|
time.sleep(0.1)
|
||
|
self.iostream.baudrate = new_baud
|
||
|
self.port_settings_altered = True
|
||
|
|
||
|
|
||
|
class ProtocolSocket(object):
|
||
|
'''A socket for sending and receiving datagrams of a single protocol over a
|
||
|
PULSE connection.
|
||
|
|
||
|
It is also an example of a Connection protocol handler implementation.
|
||
|
'''
|
||
|
|
||
|
def __init__(self, connection, protocol):
|
||
|
self.connection = connection
|
||
|
self.protocol = protocol
|
||
|
self.receive_queue = Queue.Queue()
|
||
|
self.connection.register_protocol_handler(protocol, self)
|
||
|
|
||
|
def on_receive(self, frame):
|
||
|
self.receive_queue.put(frame)
|
||
|
|
||
|
def receive(self, block=True, timeout=None):
|
||
|
try:
|
||
|
return self.receive_queue.get(block, timeout)
|
||
|
except Queue.Empty:
|
||
|
raise exceptions.ReceiveQueueEmpty
|
||
|
|
||
|
def send(self, frame):
|
||
|
self.connection.send(self.protocol, frame)
|
||
|
|
||
|
@property
|
||
|
def mtu(self):
|
||
|
return self.connection.mtu
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
logging.basicConfig(level=logging.INFO)
|
||
|
with Connection.open_dbgserial(sys.argv[1]) as sock:
|
||
|
sock.change_baud_rate(921600)
|
||
|
for _ in xrange(20):
|
||
|
time.sleep(0.5)
|
||
|
send_time = time.time()
|
||
|
if sock.ping():
|
||
|
print "Ping rtt=%.2f ms" % ((time.time() - send_time) * 1000)
|
||
|
else:
|
||
|
print "No echo"
|