# Mirror of https://github.com/google/pebble.git
# (synced 2025-03-21 19:31:20 +00:00; 301 lines, 10 KiB, Python)
# Copyright 2024 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import collections
|
|
import struct
|
|
import time
|
|
|
|
from . import exceptions
|
|
from . import socket
|
|
|
|
|
|
class EraseCommand(object):
    """Erase a range of flash on the device.

    The command packet is the command type byte followed by the start
    address and length (little-endian uint32s). Response packets carry the
    address and length back plus a ``complete`` flag.
    """

    command_type = 1
    command_struct = struct.Struct('<BII')

    response_type = 128
    # Response layout: type byte (skipped), address, length, complete flag.
    response_struct = struct.Struct('<xII?')
    Response = collections.namedtuple(
        'EraseResponse', 'address length complete')

    def __init__(self, address, length):
        self.address = address
        self.length = length

    @property
    def packet(self):
        """The serialized command packet."""
        return self.command_struct.pack(
            self.command_type, self.address, self.length)

    def parse_response(self, response):
        """Parse and validate a response packet for this command.

        Args:
            response: the raw response packet (str on Python 2, or
                bytes/bytearray).

        Returns:
            An ``EraseResponse`` namedtuple ``(address, length, complete)``.

        Raises:
            exceptions.ResponseParseError: if the packet is empty, is not an
                erase response, has the wrong size, or is for a different
                address/length than this command.
        """
        # bytearray() indexes to an int for Python 2 str as well as Python 3
        # bytes/bytearray; ord(response[0]) only works for Python 2 str and
        # raises IndexError on an empty packet.
        packet_type = bytearray(response[:1])
        if not packet_type or packet_type[0] != self.response_type:
            raise exceptions.ResponseParseError(
                'Unexpected response: %r' % response)
        # Report a truncated/oversized packet as a parse error rather than
        # letting struct.error escape.
        if len(response) != self.response_struct.size:
            raise exceptions.ResponseParseError(
                'Response has length %d (expected %d): %r' % (
                    len(response), self.response_struct.size, response))
        unpacked = self.Response._make(self.response_struct.unpack(response))
        if unpacked.address != self.address or unpacked.length != self.length:
            raise exceptions.ResponseParseError(
                'Response does not match command: '
                'address=%#.08x length=%d (expected %#.08x, %d)' % (
                    unpacked.address, unpacked.length, self.address,
                    self.length))
        return unpacked
|
|
|
|
|
|
class WriteCommand(object):
    """Write one segment of data to flash at a given address.

    The packet is a fixed header (command type byte, little-endian uint32
    address) with the raw segment bytes appended. This class only builds
    the packet; acknowledgements are parsed by WriteResponse.
    """

    command_type = 2
    command_struct = struct.Struct('<BI')
    # Size of the header that precedes the payload in every packet.
    header_len = command_struct.size

    def __init__(self, address, data):
        self.address, self.data = address, data

    @property
    def packet(self):
        """Header bytes immediately followed by the segment payload."""
        layout = self.command_struct
        return layout.pack(self.command_type, self.address) + self.data
|
|
|
|
|
|
class WriteResponse(object):
    """Parser for the device's response to a WriteCommand."""

    response_type = 129
    # Response layout: type byte (skipped), address, length, complete flag.
    response_struct = struct.Struct('<xII?')
    Response = collections.namedtuple(
        'WriteResponse', 'address length complete')

    @classmethod
    def parse(cls, response):
        """Parse a write response packet.

        Matching the response against an outstanding write is left to the
        caller.

        Args:
            response: the raw response packet (str on Python 2, or
                bytes/bytearray).

        Returns:
            A ``WriteResponse`` namedtuple ``(address, length, complete)``.

        Raises:
            exceptions.ResponseParseError: if the packet is empty, is not a
                write response, or has the wrong size.
        """
        # bytearray() gives an int for the first byte on both Python 2 str
        # and Python 3 bytes; ord(response[0]) fails on the latter.
        packet_type = bytearray(response[:1])
        if not packet_type or packet_type[0] != cls.response_type:
            raise exceptions.ResponseParseError(
                'Unexpected response: %r' % response)
        if len(response) != cls.response_struct.size:
            raise exceptions.ResponseParseError(
                'Response has length %d (expected %d): %r' % (
                    len(response), cls.response_struct.size, response))
        return cls.Response._make(cls.response_struct.unpack(response))
|
|
|
|
|
|
class CrcCommand(object):
    """Request a CRC over a range of flash on the device.

    The command packet is the command type byte followed by the start
    address and length (little-endian uint32s). The response carries the
    address and length back along with the computed CRC value.
    """

    command_type = 3
    command_struct = struct.Struct('<BII')

    response_type = 130
    # Response layout: type byte (skipped), address, length, crc.
    response_struct = struct.Struct('<xIII')
    Response = collections.namedtuple('CrcResponse', 'address length crc')

    def __init__(self, address, length):
        self.address = address
        self.length = length

    @property
    def packet(self):
        """The serialized command packet."""
        return self.command_struct.pack(self.command_type, self.address,
                                        self.length)

    def parse_response(self, response):
        """Parse and validate a response packet for this command.

        Args:
            response: the raw response packet (str on Python 2, or
                bytes/bytearray).

        Returns:
            A ``CrcResponse`` namedtuple ``(address, length, crc)``.

        Raises:
            exceptions.ResponseParseError: if the packet is empty, is not a
                CRC response, has the wrong size, or is for a different
                address/length than this command.
        """
        # Portable first-byte read (ord(response[0]) fails on Python 3 bytes
        # and raises IndexError on an empty packet).
        packet_type = bytearray(response[:1])
        if not packet_type or packet_type[0] != self.response_type:
            raise exceptions.ResponseParseError(
                'Unexpected response: %r' % response)
        if len(response) != self.response_struct.size:
            raise exceptions.ResponseParseError(
                'Response has length %d (expected %d): %r' % (
                    len(response), self.response_struct.size, response))
        unpacked = self.Response._make(self.response_struct.unpack(response))
        if unpacked.address != self.address or unpacked.length != self.length:
            raise exceptions.ResponseParseError(
                'Response does not match command: '
                'address=%#.08x length=%d (expected %#.08x, %d)' % (
                    unpacked.address, unpacked.length, self.address,
                    self.length))
        return unpacked
|
|
|
|
|
|
class QueryFlashRegionCommand(object):
    """Ask the device for the address and length of a flash region.

    The command packet is the command type byte followed by a region ID
    byte (one of the REGION_* constants). The response carries a region
    byte followed by the region's address and length (little-endian
    uint32s).
    """

    command_type = 4
    command_struct = struct.Struct('<BB')

    REGION_PRF = 1
    REGION_SYSTEM_RESOURCES = 2

    response_type = 131
    # Response layout: type byte (skipped), region, address, length.
    response_struct = struct.Struct('<xBII')
    Response = collections.namedtuple(
        'FlashRegionGeometry', 'region address length')

    def __init__(self, region):
        self.region = region

    @property
    def packet(self):
        """The serialized command packet."""
        return self.command_struct.pack(self.command_type, self.region)

    def parse_response(self, response):
        """Parse and validate a region-geometry response.

        Args:
            response: the raw response packet (str on Python 2, or
                bytes/bytearray).

        Returns:
            A ``FlashRegionGeometry`` namedtuple ``(region, address, length)``.

        Raises:
            exceptions.ResponseParseError: if the packet is empty, is not a
                region-geometry response, has the wrong size, or is for a
                different region than the one queried.
            exceptions.RegionDoesNotExist: if the device reports an address
                and length of zero.
        """
        # Portable first-byte read (ord(response[0]) fails on Python 3 bytes
        # and raises IndexError on an empty packet).
        packet_type = bytearray(response[:1])
        if not packet_type or packet_type[0] != self.response_type:
            raise exceptions.ResponseParseError(
                'Unexpected response: %r' % response)
        if len(response) != self.response_struct.size:
            raise exceptions.ResponseParseError(
                'Response has length %d (expected %d): %r' % (
                    len(response), self.response_struct.size, response))
        unpacked = self.Response._make(self.response_struct.unpack(response))
        # Reject a stale reply to an earlier query for a different region.
        # NOTE(review): assumes the device echoes the requested region byte,
        # as FinalizeFlashRegionCommand's response check already relies on.
        if unpacked.region != self.region:
            raise exceptions.ResponseParseError(
                'Response does not match command: '
                'response is for region %d (expected %d)' % (
                    unpacked.region, self.region))
        # An all-zero geometry is how the device signals an unknown region.
        if unpacked.address == 0 and unpacked.length == 0:
            raise exceptions.RegionDoesNotExist(self.region)
        return unpacked
|
|
|
|
|
|
class FinalizeFlashRegionCommand(object):
    """Send a finalize request for a flash region and validate the reply.

    The command packet is the command type byte followed by a region ID
    byte; the response is the response type byte followed by a region byte.
    """

    command_type = 5
    command_struct = struct.Struct('<BB')

    response_type = 132
    # Response layout: type byte (skipped), region.
    response_struct = struct.Struct('<xB')

    def __init__(self, region):
        self.region = region

    @property
    def packet(self):
        """The serialized command packet."""
        return self.command_struct.pack(self.command_type, self.region)

    def parse_response(self, response):
        """Validate a finalize response packet.

        Args:
            response: the raw response packet (str on Python 2, or
                bytes/bytearray).

        Returns:
            None; the response carries no data beyond the region byte.

        Raises:
            exceptions.ResponseParseError: if the packet is empty, is not a
                finalize response, has the wrong size, or is for a different
                region than this command.
        """
        # Portable first-byte read (ord(response[0]) fails on Python 3 bytes
        # and raises IndexError on an empty packet).
        packet_type = bytearray(response[:1])
        if not packet_type or packet_type[0] != self.response_type:
            raise exceptions.ResponseParseError(
                'Unexpected response: %r' % response)
        if len(response) != self.response_struct.size:
            raise exceptions.ResponseParseError(
                'Response has length %d (expected %d): %r' % (
                    len(response), self.response_struct.size, response))
        region, = self.response_struct.unpack(response)
        if region != self.region:
            raise exceptions.ResponseParseError(
                'Response does not match command: '
                'response is for region %d (expected %d)' % (
                    region, self.region))
|
|
|
|
|
|
class FlashImagingProtocol(object):
    """Client for the flash imaging protocol (protocol number 0x02).

    Opens a ProtocolSocket on the given connection and exposes erase, write,
    CRC, region-geometry query and region-finalize operations built on the
    command classes defined in this module.
    """

    PROTOCOL_NUMBER = 0x02

    # NOTE(review): not referenced within this class; presumably response
    # type codes the device uses to report errors — confirm.
    RESP_BAD_CMD = 192
    RESP_INTERNAL_ERROR = 193

    REGION_PRF = QueryFlashRegionCommand.REGION_PRF
    REGION_SYSTEM_RESOURCES = QueryFlashRegionCommand.REGION_SYSTEM_RESOURCES

    def __init__(self, connection):
        self.socket = socket.ProtocolSocket(connection,
                                            self.PROTOCOL_NUMBER)

    def erase(self, address, length):
        """Erase ``length`` bytes of flash starting at ``address``.

        The command is sent and responses are awaited until one reports
        ``complete``. Before the first response arrives, each receive waits
        1.5 s; after that, 5 s. Every receive timeout counts as a retry and
        causes the command to be sent again.

        NOTE(review): retries only count timeouts, so a device that keeps
        answering with incomplete responses keeps this loop running.

        Raises:
            exceptions.CommandTimedOut: after 10 receive timeouts.
            Errors from EraseCommand.parse_response propagate unchanged.
        """
        cmd = EraseCommand(address, length)
        ack_received = False
        retries = 0
        while retries < 10:
            if not ack_received:
                self.socket.send(cmd.packet)
            try:
                packet = self.socket.receive(
                    timeout=5 if ack_received else 1.5)
            except exceptions.ReceiveQueueEmpty:
                # Timed out: resend the command on the next iteration.
                ack_received = False
                retries += 1
                continue
            response = cmd.parse_response(packet)
            ack_received = True
            if response.complete:
                return
        raise exceptions.CommandTimedOut

    def write(self, address, data, max_retries=5, max_in_flight=5,
              progress_cb=None):
        """Write ``data`` to flash starting at ``address``.

        The data is split into segments that each fit in one packet, and up
        to ``max_in_flight`` segments are kept unacknowledged at a time. A
        segment whose ACK has not arrived within 0.5 s of being sent is
        resent.

        Args:
            address: flash address of the first byte of ``data``.
            data: the bytes to write.
            max_retries: number of resends allowed per segment.
            max_in_flight: maximum number of unacknowledged segments (>= 1).
            progress_cb: optional callable; called with True when a segment
                is acknowledged and with False when a segment is resent.

        Returns:
            The total number of segment resends.

        Raises:
            ValueError: if the socket MTU leaves no room for data, or
                ``max_in_flight`` is less than 1.
            exceptions.WriteError: on an ACK for a segment that was never
                sent, an ACK whose length differs from the segment's, an ACK
                not marked complete, or a segment exceeding ``max_retries``.
            Errors from WriteResponse.parse propagate unchanged.
        """
        mtu = self.socket.mtu - WriteCommand.header_len
        # Explicit checks instead of assert: under `python -O` a bad MTU
        # would otherwise silently write nothing, and max_in_flight < 1
        # would loop forever without sending anything.
        if mtu <= 0:
            raise ValueError(
                'Socket MTU %d leaves no room for write data'
                % self.socket.mtu)
        if max_in_flight < 1:
            raise ValueError('max_in_flight must be at least 1, got %r'
                             % (max_in_flight,))

        unsent = collections.OrderedDict()
        for offset in range(0, len(data), mtu):
            seg_address = address + offset
            unsent[seg_address] = WriteCommand(
                seg_address, data[offset:offset + mtu])

        # seg_address -> (cmd, send_time, retry_count), in send order.
        in_flight = collections.OrderedDict()
        # Segments already acknowledged. A resent segment can be ACKed twice
        # (the late original ACK plus the resend's ACK); the second one must
        # be ignored rather than treated as an unknown segment.
        acked = set()
        retries = 0
        while unsent or in_flight:
            # Process ACKs (if any) without blocking.
            try:
                while True:
                    ack = WriteResponse.parse(
                        self.socket.receive(block=False))
                    entry = in_flight.pop(ack.address, None)
                    if entry is None:
                        if ack.address in acked:
                            continue  # Duplicate ACK for a resent segment.
                        raise exceptions.WriteError(
                            'Received ACK for an unknown segment: '
                            '%#.08x' % ack.address)
                    cmd = entry[0]
                    if len(cmd.data) != ack.length:
                        raise exceptions.WriteError(
                            'ACK length %d != data length %d' % (
                                ack.length, len(cmd.data)))
                    if not ack.complete:
                        raise exceptions.WriteError(
                            'ACK for segment %#.08x is not marked complete'
                            % ack.address)
                    acked.add(ack.address)
                    if progress_cb:
                        progress_cb(True)
            except exceptions.ReceiveQueueEmpty:
                pass

            # Retry any in_flight writes where the ACK has timed out.
            timeout_time = time.time() - 0.5
            to_retry = []
            for seg_address, (_, send_time, _) in in_flight.items():
                if send_time > timeout_time:
                    # in_flight is an OrderedDict and resent segments are
                    # re-inserted at the end, so iteration is in
                    # chronological order of the last send.
                    break
                to_retry.append(seg_address)
            retries += len(to_retry)
            for seg_address in to_retry:
                cmd, _, retry_count = in_flight.pop(seg_address)
                if retry_count >= max_retries:
                    raise exceptions.WriteError(
                        'Segment %#.08x exceeded the max retry count (%d)' % (
                            seg_address, max_retries))
                self.socket.send(cmd.packet)
                in_flight[seg_address] = (cmd, time.time(), retry_count + 1)
                if progress_cb:
                    progress_cb(False)

            # Send out fresh segments.
            while unsent and len(in_flight) < max_in_flight:
                _, cmd = unsent.popitem(last=False)
                self.socket.send(cmd.packet)
                in_flight[cmd.address] = (cmd, time.time(), 0)

            # Give other threads a chance to run.
            time.sleep(0)
        return retries

    def _command_and_response(self, cmd, timeout=0.5, attempts=5):
        """Send ``cmd`` and return its parsed response, resending on timeout.

        Args:
            cmd: a command object with ``packet`` and ``parse_response``.
            timeout: seconds to wait for a response after each send.
            attempts: maximum number of sends.

        Returns:
            Whatever ``cmd.parse_response`` returns for the first response.

        Raises:
            exceptions.CommandTimedOut: if no response arrives after
                ``attempts`` sends.
        """
        for _ in range(attempts):
            self.socket.send(cmd.packet)
            try:
                packet = self.socket.receive(timeout=timeout)
            except exceptions.ReceiveQueueEmpty:
                continue
            return cmd.parse_response(packet)
        raise exceptions.CommandTimedOut

    def crc(self, address, length):
        """Return the device-computed CRC of ``length`` bytes at ``address``."""
        cmd = CrcCommand(address, length)
        return self._command_and_response(cmd, timeout=1).crc

    def query_region_geometry(self, region):
        """Return the ``FlashRegionGeometry`` (region, address, length)."""
        cmd = QueryFlashRegionCommand(region)
        return self._command_and_response(cmd)

    def finalize_region(self, region):
        """Send a finalize request for ``region`` and validate the reply."""
        cmd = FinalizeFlashRegionCommand(region)
        return self._command_and_response(cmd)
|