# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import collections import struct import time from . import exceptions from . import socket class EraseCommand(object): command_type = 1 command_struct = struct.Struct(' 0) unsent = collections.OrderedDict() for offset in xrange(0, len(data), mtu): segment = data[offset:offset+mtu] assert(len(segment)) seg_address = address + offset unsent[seg_address] = WriteCommand(seg_address, segment) in_flight = collections.OrderedDict() retries = 0 while unsent or in_flight: try: while True: # Process ACKs (if any) ack = WriteResponse.parse( self.socket.receive(block=False)) try: cmd, _, _ = in_flight[ack.address] except KeyError: raise exceptions.WriteError( 'Received ACK for an unknown segment: ' '%#.08x' % ack.address) if len(cmd.data) != ack.length: raise exceptions.WriteError( 'ACK length %d != data length %d' % ( ack.length, len(cmd.data))) assert(ack.complete) del in_flight[ack.address] if progress_cb: progress_cb(True) except exceptions.ReceiveQueueEmpty: pass # Retry any in_flight writes where the ACK has timed out to_retry = [] timeout_time = time.time() - 0.5 for seg_address, (_, send_time, _) in in_flight.iteritems(): if send_time > timeout_time: # in_flight is an OrderedDict so iteration is in # chronological order. 
break to_retry.append(seg_address) retries += len(to_retry) for seg_address in to_retry: cmd, send_time, retry_count = in_flight[seg_address] del in_flight[seg_address] if retry_count >= max_retries: raise exceptions.WriteError( 'Segment %#.08x exceeded the max retry count (%d)' % ( seg_address, max_retries)) retry_count += 1 self.socket.send(cmd.packet) in_flight[seg_address] = (cmd, time.time(), retry_count) if progress_cb: progress_cb(False) # Send out fresh segments try: while len(in_flight) < max_in_flight: seg_address, cmd = unsent.popitem(last=False) self.socket.send(cmd.packet) in_flight[cmd.address] = (cmd, time.time(), 0) except KeyError: pass # Give other threads a chance to run time.sleep(0) return retries

    def _command_and_response(self, cmd, timeout=0.5):
        """Send a command and return its parsed response, resending on timeout.

        The command's packet is sent on ``self.socket`` and a single reply is
        awaited. If the receive raises ``exceptions.ReceiveQueueEmpty`` the
        same packet is sent again, for at most five sends in total.

        Args:
            cmd: Command object exposing a ``packet`` attribute (what gets
                sent) and a ``parse_response(packet)`` method (what decodes
                the reply).
            timeout: Passed through as ``timeout=`` to
                ``self.socket.receive`` for each attempt. Presumably in
                seconds -- TODO confirm against the socket implementation.

        Returns:
            Whatever ``cmd.parse_response`` returns for the received packet.

        Raises:
            exceptions.CommandTimedOut: No reply was received on any of the
                five attempts.
        """
        # The loop variable is only a counter; each iteration is one
        # send-then-wait attempt.
        for attempt in xrange(5):
            self.socket.send(cmd.packet)
            try:
                packet = self.socket.receive(timeout=timeout)
                # Only ReceiveQueueEmpty is handled below, so any other error
                # raised while receiving or parsing propagates to the caller
                # without a retry.
                return cmd.parse_response(packet)
            except exceptions.ReceiveQueueEmpty:
                # Nothing arrived in time; fall through and resend.
                pass
        raise exceptions.CommandTimedOut

    def crc(self, address, length):
        """Issue a ``CrcCommand`` for a range and return the reported CRC.

        Args:
            address: First constructor argument of ``CrcCommand``.
            length: Second constructor argument of ``CrcCommand``.

        Returns:
            The ``crc`` attribute of the parsed response.

        Raises:
            exceptions.CommandTimedOut: No reply was received after all
                attempts.
        """
        cmd = CrcCommand(address, length)
        # Uses a per-attempt timeout of 1 instead of the 0.5 default.
        return self._command_and_response(cmd, timeout=1).crc

    def query_region_geometry(self, region):
        """Issue a ``QueryFlashRegionCommand`` and return the parsed response.

        Args:
            region: Constructor argument of ``QueryFlashRegionCommand``.

        Returns:
            The value produced by the command's ``parse_response``.

        Raises:
            exceptions.CommandTimedOut: No reply was received after all
                attempts.
        """
        cmd = QueryFlashRegionCommand(region)
        return self._command_and_response(cmd)

    def finalize_region(self, region):
        """Issue a ``FinalizeFlashRegionCommand`` and return the parsed response.

        Args:
            region: Constructor argument of ``FinalizeFlashRegionCommand``.

        Returns:
            The value produced by the command's ``parse_response``.

        Raises:
            exceptions.CommandTimedOut: No reply was received after all
                attempts.
        """
        cmd = FinalizeFlashRegionCommand(region)
        return self._command_and_response(cmd)