github.com/westcoastroms/westcoastroms-build@v0.0.0-20190928114312-2350e5a73030/build/make/tools/releasetools/blockimgdiff.py (about)

     1  # Copyright (C) 2014 The Android Open Source Project
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #      http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  from __future__ import print_function
    16  
    17  import array
    18  import copy
    19  import functools
    20  import heapq
    21  import itertools
    22  import multiprocessing
    23  import os
    24  import os.path
    25  import re
    26  import subprocess
    27  import sys
    28  import threading
    29  from collections import deque, OrderedDict
    30  from hashlib import sha1
    31  
    32  import common
    33  from rangelib import RangeSet
    34  
    35  
    36  __all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
    37  
    38  
    39  def compute_patch(srcfile, tgtfile, imgdiff=False):
    40    patchfile = common.MakeTempFile(prefix='patch-')
    41  
    42    cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
    43    cmd.extend([srcfile, tgtfile, patchfile])
    44  
    45    # Don't dump the bsdiff/imgdiff commands, which are not useful for the case
    46    # here, since they contain temp filenames only.
    47    p = common.Run(cmd, verbose=False, stdout=subprocess.PIPE,
    48                   stderr=subprocess.STDOUT)
    49    output, _ = p.communicate()
    50  
    51    if p.returncode != 0:
    52      raise ValueError(output)
    53  
    54    with open(patchfile, 'rb') as f:
    55      return f.read()
    56  
    57  
    58  class Image(object):
    59    def RangeSha1(self, ranges):
    60      raise NotImplementedError
    61  
    62    def ReadRangeSet(self, ranges):
    63      raise NotImplementedError
    64  
    65    def TotalSha1(self, include_clobbered_blocks=False):
    66      raise NotImplementedError
    67  
    68    def WriteRangeDataToFd(self, ranges, fd):
    69      raise NotImplementedError
    70  
    71  
    72  class EmptyImage(Image):
    73    """A zero-length image."""
    74  
    75    def __init__(self):
    76      self.blocksize = 4096
    77      self.care_map = RangeSet()
    78      self.clobbered_blocks = RangeSet()
    79      self.extended = RangeSet()
    80      self.total_blocks = 0
    81      self.file_map = {}
    82  
    83    def RangeSha1(self, ranges):
    84      return sha1().hexdigest()
    85  
    86    def ReadRangeSet(self, ranges):
    87      return ()
    88  
    89    def TotalSha1(self, include_clobbered_blocks=False):
    90      # EmptyImage always carries empty clobbered_blocks, so
    91      # include_clobbered_blocks can be ignored.
    92      assert self.clobbered_blocks.size() == 0
    93      return sha1().hexdigest()
    94  
    95    def WriteRangeDataToFd(self, ranges, fd):
    96      raise ValueError("Can't write data from EmptyImage to file")
    97  
    98  
    99  class DataImage(Image):
   100    """An image wrapped around a single string of data."""
   101  
   102    def __init__(self, data, trim=False, pad=False):
   103      self.data = data
   104      self.blocksize = 4096
   105  
   106      assert not (trim and pad)
   107  
   108      partial = len(self.data) % self.blocksize
   109      padded = False
   110      if partial > 0:
   111        if trim:
   112          self.data = self.data[:-partial]
   113        elif pad:
   114          self.data += '\0' * (self.blocksize - partial)
   115          padded = True
   116        else:
   117          raise ValueError(("data for DataImage must be multiple of %d bytes "
   118                            "unless trim or pad is specified") %
   119                           (self.blocksize,))
   120  
   121      assert len(self.data) % self.blocksize == 0
   122  
   123      self.total_blocks = len(self.data) / self.blocksize
   124      self.care_map = RangeSet(data=(0, self.total_blocks))
   125      # When the last block is padded, we always write the whole block even for
   126      # incremental OTAs. Because otherwise the last block may get skipped if
   127      # unchanged for an incremental, but would fail the post-install
   128      # verification if it has non-zero contents in the padding bytes.
   129      # Bug: 23828506
   130      if padded:
   131        clobbered_blocks = [self.total_blocks-1, self.total_blocks]
   132      else:
   133        clobbered_blocks = []
   134      self.clobbered_blocks = clobbered_blocks
   135      self.extended = RangeSet()
   136  
   137      zero_blocks = []
   138      nonzero_blocks = []
   139      reference = '\0' * self.blocksize
   140  
   141      for i in range(self.total_blocks-1 if padded else self.total_blocks):
   142        d = self.data[i*self.blocksize : (i+1)*self.blocksize]
   143        if d == reference:
   144          zero_blocks.append(i)
   145          zero_blocks.append(i+1)
   146        else:
   147          nonzero_blocks.append(i)
   148          nonzero_blocks.append(i+1)
   149  
   150      assert zero_blocks or nonzero_blocks or clobbered_blocks
   151  
   152      self.file_map = dict()
   153      if zero_blocks:
   154        self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
   155      if nonzero_blocks:
   156        self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
   157      if clobbered_blocks:
   158        self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)
   159  
   160    def _GetRangeData(self, ranges):
   161      for s, e in ranges:
   162        yield self.data[s*self.blocksize:e*self.blocksize]
   163  
   164    def RangeSha1(self, ranges):
   165      h = sha1()
   166      for data in self._GetRangeData(ranges):
   167        h.update(data)
   168      return h.hexdigest()
   169  
   170    def ReadRangeSet(self, ranges):
   171      return [self._GetRangeData(ranges)]
   172  
   173    def TotalSha1(self, include_clobbered_blocks=False):
   174      if not include_clobbered_blocks:
   175        return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
   176      else:
   177        return sha1(self.data).hexdigest()
   178  
   179    def WriteRangeDataToFd(self, ranges, fd):
   180      for data in self._GetRangeData(ranges):
   181        fd.write(data)
   182  
   183  
   184  class Transfer(object):
   185    def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
   186                 src_sha1, style, by_id):
   187      self.tgt_name = tgt_name
   188      self.src_name = src_name
   189      self.tgt_ranges = tgt_ranges
   190      self.src_ranges = src_ranges
   191      self.tgt_sha1 = tgt_sha1
   192      self.src_sha1 = src_sha1
   193      self.style = style
   194  
   195      # We use OrderedDict rather than dict so that the output is repeatable;
   196      # otherwise it would depend on the hash values of the Transfer objects.
   197      self.goes_before = OrderedDict()
   198      self.goes_after = OrderedDict()
   199  
   200      self.stash_before = []
   201      self.use_stash = []
   202  
   203      self.id = len(by_id)
   204      by_id.append(self)
   205  
   206      self._patch = None
   207  
   208    @property
   209    def patch(self):
   210      return self._patch
   211  
   212    @patch.setter
   213    def patch(self, patch):
   214      if patch:
   215        assert self.style == "diff"
   216      self._patch = patch
   217  
   218    def NetStashChange(self):
   219      return (sum(sr.size() for (_, sr) in self.stash_before) -
   220              sum(sr.size() for (_, sr) in self.use_stash))
   221  
   222    def ConvertToNew(self):
   223      assert self.style != "new"
   224      self.use_stash = []
   225      self.style = "new"
   226      self.src_ranges = RangeSet()
   227      self.patch = None
   228  
   229    def __str__(self):
   230      return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
   231              " to " + str(self.tgt_ranges) + ">")
   232  
   233  
   234  @functools.total_ordering
   235  class HeapItem(object):
   236    def __init__(self, item):
   237      self.item = item
   238      # Negate the score since python's heap is a min-heap and we want the
   239      # maximum score.
   240      self.score = -item.score
   241  
   242    def clear(self):
   243      self.item = None
   244  
   245    def __bool__(self):
   246      return self.item is not None
   247  
   248    # Python 2 uses __nonzero__, while Python 3 uses __bool__.
   249    __nonzero__ = __bool__
   250  
   251    # The rest operations are generated by functools.total_ordering decorator.
   252    def __eq__(self, other):
   253      return self.score == other.score
   254  
   255    def __le__(self, other):
   256      return self.score <= other.score
   257  
   258  
   259  class ImgdiffStats(object):
   260    """A class that collects imgdiff stats.
   261  
   262    It keeps track of the files that will be applied imgdiff while generating
   263    BlockImageDiff. It also logs the ones that cannot use imgdiff, with specific
   264    reasons. The stats is only meaningful when imgdiff not being disabled by the
   265    caller of BlockImageDiff. In addition, only files with supported types
   266    (BlockImageDiff.FileTypeSupportedByImgdiff()) are allowed to be logged.
   267    """
   268  
   269    USED_IMGDIFF = "APK files diff'd with imgdiff"
   270    USED_IMGDIFF_LARGE_APK = "Large APK files split and diff'd with imgdiff"
   271  
   272    # Reasons for not applying imgdiff on APKs.
   273    SKIPPED_TRIMMED = "Not used imgdiff due to trimmed RangeSet"
   274    SKIPPED_NONMONOTONIC = "Not used imgdiff due to having non-monotonic ranges"
   275    SKIPPED_SHARED_BLOCKS = "Not used imgdiff due to using shared blocks"
   276    SKIPPED_INCOMPLETE = "Not used imgdiff due to incomplete RangeSet"
   277  
   278    # The list of valid reasons, which will also be the dumped order in a report.
   279    REASONS = (
   280        USED_IMGDIFF,
   281        USED_IMGDIFF_LARGE_APK,
   282        SKIPPED_TRIMMED,
   283        SKIPPED_NONMONOTONIC,
   284        SKIPPED_SHARED_BLOCKS,
   285        SKIPPED_INCOMPLETE,
   286    )
   287  
   288    def  __init__(self):
   289      self.stats = {}
   290  
   291    def Log(self, filename, reason):
   292      """Logs why imgdiff can or cannot be applied to the given filename.
   293  
   294      Args:
   295        filename: The filename string.
   296        reason: One of the reason constants listed in REASONS.
   297  
   298      Raises:
   299        AssertionError: On unsupported filetypes or invalid reason.
   300      """
   301      assert BlockImageDiff.FileTypeSupportedByImgdiff(filename)
   302      assert reason in self.REASONS
   303  
   304      if reason not in self.stats:
   305        self.stats[reason] = set()
   306      self.stats[reason].add(filename)
   307  
   308    def Report(self):
   309      """Prints a report of the collected imgdiff stats."""
   310  
   311      def print_header(header, separator):
   312        print(header)
   313        print(separator * len(header) + '\n')
   314  
   315      print_header('  Imgdiff Stats Report  ', '=')
   316      for key in self.REASONS:
   317        if key not in self.stats:
   318          continue
   319        values = self.stats[key]
   320        section_header = ' {} (count: {}) '.format(key, len(values))
   321        print_header(section_header, '-')
   322        print(''.join(['  {}\n'.format(name) for name in values]))
   323  
   324  
   325  # BlockImageDiff works on two image objects.  An image object is
   326  # anything that provides the following attributes:
   327  #
   328  #    blocksize: the size in bytes of a block, currently must be 4096.
   329  #
   330  #    total_blocks: the total size of the partition/image, in blocks.
   331  #
   332  #    care_map: a RangeSet containing which blocks (in the range [0,
   333  #      total_blocks) we actually care about; i.e. which blocks contain
   334  #      data.
   335  #
   336  #    file_map: a dict that partitions the blocks contained in care_map
   337  #      into smaller domains that are useful for doing diffs on.
   338  #      (Typically a domain is a file, and the key in file_map is the
   339  #      pathname.)
   340  #
   341  #    clobbered_blocks: a RangeSet containing which blocks contain data
   342  #      but may be altered by the FS. They need to be excluded when
   343  #      verifying the partition integrity.
   344  #
   345  #    ReadRangeSet(): a function that takes a RangeSet and returns the
   346  #      data contained in the image blocks of that RangeSet.  The data
   347  #      is returned as a list or tuple of strings; concatenating the
   348  #      elements together should produce the requested data.
   349  #      Implementations are free to break up the data into list/tuple
   350  #      elements in any way that is convenient.
   351  #
   352  #    RangeSha1(): a function that returns (as a hex string) the SHA-1
   353  #      hash of all the data in the specified range.
   354  #
   355  #    TotalSha1(): a function that returns (as a hex string) the SHA-1
   356  #      hash of all the data in the image (ie, all the blocks in the
   357  #      care_map minus clobbered_blocks, or including the clobbered
   358  #      blocks if include_clobbered_blocks is True).
   359  #
   360  # When creating a BlockImageDiff, the src image may be None, in which
   361  # case the list of transfers produced will never read from the
   362  # original image.
   363  
   364  class BlockImageDiff(object):
   365    def __init__(self, tgt, src=None, threads=None, version=4,
   366                 disable_imgdiff=False):
   367      if threads is None:
   368        threads = multiprocessing.cpu_count() // 2
   369        if threads == 0:
   370          threads = 1
   371      self.threads = threads
   372      self.version = version
   373      self.transfers = []
   374      self.src_basenames = {}
   375      self.src_numpatterns = {}
   376      self._max_stashed_size = 0
   377      self.touched_src_ranges = RangeSet()
   378      self.touched_src_sha1 = None
   379      self.disable_imgdiff = disable_imgdiff
   380      self.imgdiff_stats = ImgdiffStats() if not disable_imgdiff else None
   381  
   382      assert version in (3, 4)
   383  
   384      self.tgt = tgt
   385      if src is None:
   386        src = EmptyImage()
   387      self.src = src
   388  
   389      # The updater code that installs the patch always uses 4k blocks.
   390      assert tgt.blocksize == 4096
   391      assert src.blocksize == 4096
   392  
   393      # The range sets in each filemap should comprise a partition of
   394      # the care map.
   395      self.AssertPartition(src.care_map, src.file_map.values())
   396      self.AssertPartition(tgt.care_map, tgt.file_map.values())
   397  
   398    @property
   399    def max_stashed_size(self):
   400      return self._max_stashed_size
   401  
   402    @staticmethod
   403    def FileTypeSupportedByImgdiff(filename):
   404      """Returns whether the file type is supported by imgdiff."""
   405      return filename.lower().endswith(('.apk', '.jar', '.zip'))
   406  
   407    def CanUseImgdiff(self, name, tgt_ranges, src_ranges, large_apk=False):
   408      """Checks whether we can apply imgdiff for the given RangeSets.
   409  
   410      For files in ZIP format (e.g., APKs, JARs, etc.) we would like to use
   411      'imgdiff -z' if possible. Because it usually produces significantly smaller
   412      patches than bsdiff.
   413  
   414      This is permissible if all of the following conditions hold.
   415        - The imgdiff hasn't been disabled by the caller (e.g. squashfs);
   416        - The file type is supported by imgdiff;
   417        - The source and target blocks are monotonic (i.e. the data is stored with
   418          blocks in increasing order);
   419        - Both files don't contain shared blocks;
   420        - Both files have complete lists of blocks;
   421        - We haven't removed any blocks from the source set.
   422  
   423      If all these conditions are satisfied, concatenating all the blocks in the
   424      RangeSet in order will produce a valid ZIP file (plus possibly extra zeros
   425      in the last block). imgdiff is fine with extra zeros at the end of the file.
   426  
   427      Args:
   428        name: The filename to be diff'd.
   429        tgt_ranges: The target RangeSet.
   430        src_ranges: The source RangeSet.
   431        large_apk: Whether this is to split a large APK.
   432  
   433      Returns:
   434        A boolean result.
   435      """
   436      if self.disable_imgdiff or not self.FileTypeSupportedByImgdiff(name):
   437        return False
   438  
   439      if not tgt_ranges.monotonic or not src_ranges.monotonic:
   440        self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_NONMONOTONIC)
   441        return False
   442  
   443      if (tgt_ranges.extra.get('uses_shared_blocks') or
   444          src_ranges.extra.get('uses_shared_blocks')):
   445        self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_SHARED_BLOCKS)
   446        return False
   447  
   448      if tgt_ranges.extra.get('incomplete') or src_ranges.extra.get('incomplete'):
   449        self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_INCOMPLETE)
   450        return False
   451  
   452      if tgt_ranges.extra.get('trimmed') or src_ranges.extra.get('trimmed'):
   453        self.imgdiff_stats.Log(name, ImgdiffStats.SKIPPED_TRIMMED)
   454        return False
   455  
   456      reason = (ImgdiffStats.USED_IMGDIFF_LARGE_APK if large_apk
   457                else ImgdiffStats.USED_IMGDIFF)
   458      self.imgdiff_stats.Log(name, reason)
   459      return True
   460  
   461    def Compute(self, prefix):
   462      # When looking for a source file to use as the diff input for a
   463      # target file, we try:
   464      #   1) an exact path match if available, otherwise
   465      #   2) a exact basename match if available, otherwise
   466      #   3) a basename match after all runs of digits are replaced by
   467      #      "#" if available, otherwise
   468      #   4) we have no source for this target.
   469      self.AbbreviateSourceNames()
   470      self.FindTransfers()
   471  
   472      # Find the ordering dependencies among transfers (this is O(n^2)
   473      # in the number of transfers).
   474      self.GenerateDigraph()
   475      # Find a sequence of transfers that satisfies as many ordering
   476      # dependencies as possible (heuristically).
   477      self.FindVertexSequence()
   478      # Fix up the ordering dependencies that the sequence didn't
   479      # satisfy.
   480      self.ReverseBackwardEdges()
   481      self.ImproveVertexSequence()
   482  
   483      # Ensure the runtime stash size is under the limit.
   484      if common.OPTIONS.cache_size is not None:
   485        self.ReviseStashSize()
   486  
   487      # Double-check our work.
   488      self.AssertSequenceGood()
   489      self.AssertSha1Good()
   490  
   491      self.ComputePatches(prefix)
   492      self.WriteTransfers(prefix)
   493  
   494      # Report the imgdiff stats.
   495      if common.OPTIONS.verbose and not self.disable_imgdiff:
   496        self.imgdiff_stats.Report()
   497  
   498    def WriteTransfers(self, prefix):
   499      def WriteSplitTransfers(out, style, target_blocks):
   500        """Limit the size of operand in command 'new' and 'zero' to 1024 blocks.
   501  
   502        This prevents the target size of one command from being too large; and
   503        might help to avoid fsync errors on some devices."""
   504  
   505        assert style == "new" or style == "zero"
   506        blocks_limit = 1024
   507        total = 0
   508        while target_blocks:
   509          blocks_to_write = target_blocks.first(blocks_limit)
   510          out.append("%s %s\n" % (style, blocks_to_write.to_string_raw()))
   511          total += blocks_to_write.size()
   512          target_blocks = target_blocks.subtract(blocks_to_write)
   513        return total
   514  
   515      out = []
   516      total = 0
   517  
   518      # In BBOTA v3+, it uses the hash of the stashed blocks as the stash slot
   519      # id. 'stashes' records the map from 'hash' to the ref count. The stash
   520      # will be freed only if the count decrements to zero.
   521      stashes = {}
   522      stashed_blocks = 0
   523      max_stashed_blocks = 0
   524  
   525      for xf in self.transfers:
   526  
   527        for _, sr in xf.stash_before:
   528          sh = self.src.RangeSha1(sr)
   529          if sh in stashes:
   530            stashes[sh] += 1
   531          else:
   532            stashes[sh] = 1
   533            stashed_blocks += sr.size()
   534            self.touched_src_ranges = self.touched_src_ranges.union(sr)
   535            out.append("stash %s %s\n" % (sh, sr.to_string_raw()))
   536  
   537        if stashed_blocks > max_stashed_blocks:
   538          max_stashed_blocks = stashed_blocks
   539  
   540        free_string = []
   541        free_size = 0
   542  
   543        #   <# blocks> <src ranges>
   544        #     OR
   545        #   <# blocks> <src ranges> <src locs> <stash refs...>
   546        #     OR
   547        #   <# blocks> - <stash refs...>
   548  
   549        size = xf.src_ranges.size()
   550        src_str_buffer = [str(size)]
   551  
   552        unstashed_src_ranges = xf.src_ranges
   553        mapped_stashes = []
   554        for _, sr in xf.use_stash:
   555          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
   556          sh = self.src.RangeSha1(sr)
   557          sr = xf.src_ranges.map_within(sr)
   558          mapped_stashes.append(sr)
   559          assert sh in stashes
   560          src_str_buffer.append("%s:%s" % (sh, sr.to_string_raw()))
   561          stashes[sh] -= 1
   562          if stashes[sh] == 0:
   563            free_string.append("free %s\n" % (sh,))
   564            free_size += sr.size()
   565            stashes.pop(sh)
   566  
   567        if unstashed_src_ranges:
   568          src_str_buffer.insert(1, unstashed_src_ranges.to_string_raw())
   569          if xf.use_stash:
   570            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
   571            src_str_buffer.insert(2, mapped_unstashed.to_string_raw())
   572            mapped_stashes.append(mapped_unstashed)
   573            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
   574        else:
   575          src_str_buffer.insert(1, "-")
   576          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
   577  
   578        src_str = " ".join(src_str_buffer)
   579  
   580        # version 3+:
   581        #   zero <rangeset>
   582        #   new <rangeset>
   583        #   erase <rangeset>
   584        #   bsdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
   585        #   imgdiff patchstart patchlen srchash tgthash <tgt rangeset> <src_str>
   586        #   move hash <tgt rangeset> <src_str>
   587  
   588        tgt_size = xf.tgt_ranges.size()
   589  
   590        if xf.style == "new":
   591          assert xf.tgt_ranges
   592          assert tgt_size == WriteSplitTransfers(out, xf.style, xf.tgt_ranges)
   593          total += tgt_size
   594        elif xf.style == "move":
   595          assert xf.tgt_ranges
   596          assert xf.src_ranges.size() == tgt_size
   597          if xf.src_ranges != xf.tgt_ranges:
   598            # take into account automatic stashing of overlapping blocks
   599            if xf.src_ranges.overlaps(xf.tgt_ranges):
   600              temp_stash_usage = stashed_blocks + xf.src_ranges.size()
   601              if temp_stash_usage > max_stashed_blocks:
   602                max_stashed_blocks = temp_stash_usage
   603  
   604            self.touched_src_ranges = self.touched_src_ranges.union(
   605                xf.src_ranges)
   606  
   607            out.append("%s %s %s %s\n" % (
   608                xf.style,
   609                xf.tgt_sha1,
   610                xf.tgt_ranges.to_string_raw(), src_str))
   611            total += tgt_size
   612        elif xf.style in ("bsdiff", "imgdiff"):
   613          assert xf.tgt_ranges
   614          assert xf.src_ranges
   615          # take into account automatic stashing of overlapping blocks
   616          if xf.src_ranges.overlaps(xf.tgt_ranges):
   617            temp_stash_usage = stashed_blocks + xf.src_ranges.size()
   618            if temp_stash_usage > max_stashed_blocks:
   619              max_stashed_blocks = temp_stash_usage
   620  
   621          self.touched_src_ranges = self.touched_src_ranges.union(xf.src_ranges)
   622  
   623          out.append("%s %d %d %s %s %s %s\n" % (
   624              xf.style,
   625              xf.patch_start, xf.patch_len,
   626              xf.src_sha1,
   627              xf.tgt_sha1,
   628              xf.tgt_ranges.to_string_raw(), src_str))
   629          total += tgt_size
   630        elif xf.style == "zero":
   631          assert xf.tgt_ranges
   632          to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
   633          assert WriteSplitTransfers(out, xf.style, to_zero) == to_zero.size()
   634          total += to_zero.size()
   635        else:
   636          raise ValueError("unknown transfer style '%s'\n" % xf.style)
   637  
   638        if free_string:
   639          out.append("".join(free_string))
   640          stashed_blocks -= free_size
   641  
   642        if common.OPTIONS.cache_size is not None:
   643          # Sanity check: abort if we're going to need more stash space than
   644          # the allowed size (cache_size * threshold). There are two purposes
   645          # of having a threshold here. a) Part of the cache may have been
   646          # occupied by some recovery logs. b) It will buy us some time to deal
   647          # with the oversize issue.
   648          cache_size = common.OPTIONS.cache_size
   649          stash_threshold = common.OPTIONS.stash_threshold
   650          max_allowed = cache_size * stash_threshold
   651          assert max_stashed_blocks * self.tgt.blocksize <= max_allowed, \
   652                 'Stash size %d (%d * %d) exceeds the limit %d (%d * %.2f)' % (
   653                     max_stashed_blocks * self.tgt.blocksize, max_stashed_blocks,
   654                     self.tgt.blocksize, max_allowed, cache_size,
   655                     stash_threshold)
   656  
   657      self.touched_src_sha1 = self.src.RangeSha1(self.touched_src_ranges)
   658  
   659      # Zero out extended blocks as a workaround for bug 20881595.
   660      if self.tgt.extended:
   661        assert (WriteSplitTransfers(out, "zero", self.tgt.extended) ==
   662                self.tgt.extended.size())
   663        total += self.tgt.extended.size()
   664  
   665      # We erase all the blocks on the partition that a) don't contain useful
   666      # data in the new image; b) will not be touched by dm-verity. Out of those
   667      # blocks, we erase the ones that won't be used in this update at the
   668      # beginning of an update. The rest would be erased at the end. This is to
   669      # work around the eMMC issue observed on some devices, which may otherwise
   670      # get starving for clean blocks and thus fail the update. (b/28347095)
   671      all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
   672      all_tgt_minus_extended = all_tgt.subtract(self.tgt.extended)
   673      new_dontcare = all_tgt_minus_extended.subtract(self.tgt.care_map)
   674  
   675      erase_first = new_dontcare.subtract(self.touched_src_ranges)
   676      if erase_first:
   677        out.insert(0, "erase %s\n" % (erase_first.to_string_raw(),))
   678  
   679      erase_last = new_dontcare.subtract(erase_first)
   680      if erase_last:
   681        out.append("erase %s\n" % (erase_last.to_string_raw(),))
   682  
   683      out.insert(0, "%d\n" % (self.version,))   # format version number
   684      out.insert(1, "%d\n" % (total,))
   685      # v3+: the number of stash slots is unused.
   686      out.insert(2, "0\n")
   687      out.insert(3, str(max_stashed_blocks) + "\n")
   688  
   689      with open(prefix + ".transfer.list", "wb") as f:
   690        for i in out:
   691          f.write(i)
   692  
   693      self._max_stashed_size = max_stashed_blocks * self.tgt.blocksize
   694      OPTIONS = common.OPTIONS
   695      if OPTIONS.cache_size is not None:
   696        max_allowed = OPTIONS.cache_size * OPTIONS.stash_threshold
   697        print("max stashed blocks: %d  (%d bytes), "
   698              "limit: %d bytes (%.2f%%)\n" % (
   699                  max_stashed_blocks, self._max_stashed_size, max_allowed,
   700                  self._max_stashed_size * 100.0 / max_allowed))
   701      else:
   702        print("max stashed blocks: %d  (%d bytes), limit: <unknown>\n" % (
   703            max_stashed_blocks, self._max_stashed_size))
   704  
   705    def ReviseStashSize(self):
   706      print("Revising stash size...")
   707      stash_map = {}
   708  
   709      # Create the map between a stash and its def/use points. For example, for a
   710      # given stash of (raw_id, sr), stash_map[raw_id] = (sr, def_cmd, use_cmd).
   711      for xf in self.transfers:
   712        # Command xf defines (stores) all the stashes in stash_before.
   713        for stash_raw_id, sr in xf.stash_before:
   714          stash_map[stash_raw_id] = (sr, xf)
   715  
   716        # Record all the stashes command xf uses.
   717        for stash_raw_id, _ in xf.use_stash:
   718          stash_map[stash_raw_id] += (xf,)
   719  
   720      # Compute the maximum blocks available for stash based on /cache size and
   721      # the threshold.
   722      cache_size = common.OPTIONS.cache_size
   723      stash_threshold = common.OPTIONS.stash_threshold
   724      max_allowed = cache_size * stash_threshold / self.tgt.blocksize
   725  
   726      # See the comments for 'stashes' in WriteTransfers().
   727      stashes = {}
   728      stashed_blocks = 0
   729      new_blocks = 0
   730  
   731      # Now go through all the commands. Compute the required stash size on the
   732      # fly. If a command requires excess stash than available, it deletes the
   733      # stash by replacing the command that uses the stash with a "new" command
   734      # instead.
   735      for xf in self.transfers:
   736        replaced_cmds = []
   737  
   738        # xf.stash_before generates explicit stash commands.
   739        for stash_raw_id, sr in xf.stash_before:
   740          # Check the post-command stashed_blocks.
   741          stashed_blocks_after = stashed_blocks
   742          sh = self.src.RangeSha1(sr)
   743          if sh not in stashes:
   744            stashed_blocks_after += sr.size()
   745  
   746          if stashed_blocks_after > max_allowed:
   747            # We cannot stash this one for a later command. Find out the command
   748            # that will use this stash and replace the command with "new".
   749            use_cmd = stash_map[stash_raw_id][2]
   750            replaced_cmds.append(use_cmd)
   751            print("%10d  %9s  %s" % (sr.size(), "explicit", use_cmd))
   752          else:
   753            # Update the stashes map.
   754            if sh in stashes:
   755              stashes[sh] += 1
   756            else:
   757              stashes[sh] = 1
   758            stashed_blocks = stashed_blocks_after
   759  
   760        # "move" and "diff" may introduce implicit stashes in BBOTA v3. Prior to
   761        # ComputePatches(), they both have the style of "diff".
   762        if xf.style == "diff":
   763          assert xf.tgt_ranges and xf.src_ranges
   764          if xf.src_ranges.overlaps(xf.tgt_ranges):
   765            if stashed_blocks + xf.src_ranges.size() > max_allowed:
   766              replaced_cmds.append(xf)
   767              print("%10d  %9s  %s" % (xf.src_ranges.size(), "implicit", xf))
   768  
   769        # Replace the commands in replaced_cmds with "new"s.
   770        for cmd in replaced_cmds:
   771          # It no longer uses any commands in "use_stash". Remove the def points
   772          # for all those stashes.
   773          for stash_raw_id, sr in cmd.use_stash:
   774            def_cmd = stash_map[stash_raw_id][1]
   775            assert (stash_raw_id, sr) in def_cmd.stash_before
   776            def_cmd.stash_before.remove((stash_raw_id, sr))
   777  
   778          # Add up blocks that violates space limit and print total number to
   779          # screen later.
   780          new_blocks += cmd.tgt_ranges.size()
   781          cmd.ConvertToNew()
   782  
   783        # xf.use_stash may generate free commands.
   784        for _, sr in xf.use_stash:
   785          sh = self.src.RangeSha1(sr)
   786          assert sh in stashes
   787          stashes[sh] -= 1
   788          if stashes[sh] == 0:
   789            stashed_blocks -= sr.size()
   790            stashes.pop(sh)
   791  
   792      num_of_bytes = new_blocks * self.tgt.blocksize
   793      print("  Total %d blocks (%d bytes) are packed as new blocks due to "
   794            "insufficient cache size." % (new_blocks, num_of_bytes))
   795      return new_blocks
   796  
   797    def ComputePatches(self, prefix):
   798      print("Reticulating splines...")
   799      diff_queue = []
   800      patch_num = 0
   801      with open(prefix + ".new.dat", "wb") as new_f:
   802        for index, xf in enumerate(self.transfers):
   803          if xf.style == "zero":
   804            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
   805            print("%10d %10d (%6.2f%%) %7s %s %s" % (
   806                tgt_size, tgt_size, 100.0, xf.style, xf.tgt_name,
   807                str(xf.tgt_ranges)))
   808  
   809          elif xf.style == "new":
   810            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, new_f)
   811            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
   812            print("%10d %10d (%6.2f%%) %7s %s %s" % (
   813                tgt_size, tgt_size, 100.0, xf.style,
   814                xf.tgt_name, str(xf.tgt_ranges)))
   815  
   816          elif xf.style == "diff":
   817            # We can't compare src and tgt directly because they may have
   818            # the same content but be broken up into blocks differently, eg:
   819            #
   820            #    ["he", "llo"]  vs  ["h", "ello"]
   821            #
   822            # We want those to compare equal, ideally without having to
   823            # actually concatenate the strings (these may be tens of
   824            # megabytes).
   825            if xf.src_sha1 == xf.tgt_sha1:
   826              # These are identical; we don't need to generate a patch,
   827              # just issue copy commands on the device.
   828              xf.style = "move"
   829              xf.patch = None
   830              tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
   831              if xf.src_ranges != xf.tgt_ranges:
   832                print("%10d %10d (%6.2f%%) %7s %s %s (from %s)" % (
   833                    tgt_size, tgt_size, 100.0, xf.style,
   834                    xf.tgt_name if xf.tgt_name == xf.src_name else (
   835                        xf.tgt_name + " (from " + xf.src_name + ")"),
   836                    str(xf.tgt_ranges), str(xf.src_ranges)))
   837            else:
   838              if xf.patch:
   839                # We have already generated the patch with imgdiff. Check if the
   840                # transfer is intact.
   841                assert not self.disable_imgdiff
   842                imgdiff = True
   843                if (xf.src_ranges.extra.get('trimmed') or
   844                    xf.tgt_ranges.extra.get('trimmed')):
   845                  imgdiff = False
   846                  xf.patch = None
   847              else:
   848                imgdiff = self.CanUseImgdiff(
   849                    xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
   850              xf.style = "imgdiff" if imgdiff else "bsdiff"
   851              diff_queue.append((index, imgdiff, patch_num))
   852              patch_num += 1
   853  
   854          else:
   855            assert False, "unknown style " + xf.style
   856  
   857      if diff_queue:
   858        if self.threads > 1:
   859          print("Computing patches (using %d threads)..." % (self.threads,))
   860        else:
   861          print("Computing patches...")
   862  
   863        diff_total = len(diff_queue)
   864        patches = [None] * diff_total
   865        error_messages = []
   866  
   867        # Using multiprocessing doesn't give additional benefits, due to the
   868        # pattern of the code. The diffing work is done by subprocess.call, which
   869        # already runs in a separate process (not affected much by the GIL -
   870        # Global Interpreter Lock). Using multiprocess also requires either a)
   871        # writing the diff input files in the main process before forking, or b)
   872        # reopening the image file (SparseImage) in the worker processes. Doing
   873        # neither of them further improves the performance.
   874        lock = threading.Lock()
   875        def diff_worker():
   876          while True:
   877            with lock:
   878              if not diff_queue:
   879                return
   880              xf_index, imgdiff, patch_index = diff_queue.pop()
   881              xf = self.transfers[xf_index]
   882  
   883              if sys.stdout.isatty():
   884                diff_left = len(diff_queue)
   885                progress = (diff_total - diff_left) * 100 / diff_total
   886                # '\033[K' is to clear to EOL.
   887                print(' [%3d%%] %s\033[K' % (progress, xf.tgt_name), end='\r')
   888                sys.stdout.flush()
   889  
   890            patch = xf.patch
   891            if not patch:
   892              src_ranges = xf.src_ranges
   893              tgt_ranges = xf.tgt_ranges
   894  
   895              src_file = common.MakeTempFile(prefix="src-")
   896              with open(src_file, "wb") as fd:
   897                self.src.WriteRangeDataToFd(src_ranges, fd)
   898  
   899              tgt_file = common.MakeTempFile(prefix="tgt-")
   900              with open(tgt_file, "wb") as fd:
   901                self.tgt.WriteRangeDataToFd(tgt_ranges, fd)
   902  
   903              message = []
   904              try:
   905                patch = compute_patch(src_file, tgt_file, imgdiff)
   906              except ValueError as e:
   907                message.append(
   908                    "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
   909                        "imgdiff" if imgdiff else "bsdiff",
   910                        xf.tgt_name if xf.tgt_name == xf.src_name else
   911                        xf.tgt_name + " (from " + xf.src_name + ")",
   912                        xf.tgt_ranges, xf.src_ranges, e.message))
   913              if message:
   914                with lock:
   915                  error_messages.extend(message)
   916  
   917            with lock:
   918              patches[patch_index] = (xf_index, patch)
   919  
   920        threads = [threading.Thread(target=diff_worker)
   921                   for _ in range(self.threads)]
   922        for th in threads:
   923          th.start()
   924        while threads:
   925          threads.pop().join()
   926  
   927        if sys.stdout.isatty():
   928          print('\n')
   929  
   930        if error_messages:
   931          print('ERROR:')
   932          print('\n'.join(error_messages))
   933          print('\n\n\n')
   934          sys.exit(1)
   935      else:
   936        patches = []
   937  
   938      offset = 0
   939      with open(prefix + ".patch.dat", "wb") as patch_fd:
   940        for index, patch in patches:
   941          xf = self.transfers[index]
   942          xf.patch_len = len(patch)
   943          xf.patch_start = offset
   944          offset += xf.patch_len
   945          patch_fd.write(patch)
   946  
   947          if common.OPTIONS.verbose:
   948            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
   949            print("%10d %10d (%6.2f%%) %7s %s %s %s" % (
   950                xf.patch_len, tgt_size, xf.patch_len * 100.0 / tgt_size,
   951                xf.style,
   952                xf.tgt_name if xf.tgt_name == xf.src_name else (
   953                    xf.tgt_name + " (from " + xf.src_name + ")"),
   954                xf.tgt_ranges, xf.src_ranges))
   955  
   956    def AssertSha1Good(self):
   957      """Check the SHA-1 of the src & tgt blocks in the transfer list.
   958  
   959      Double check the SHA-1 value to avoid the issue in b/71908713, where
   960      SparseImage.RangeSha1() messed up with the hash calculation in multi-thread
   961      environment. That specific problem has been fixed by protecting the
   962      underlying generator function 'SparseImage._GetRangeData()' with lock.
   963      """
   964      for xf in self.transfers:
   965        tgt_sha1 = self.tgt.RangeSha1(xf.tgt_ranges)
   966        assert xf.tgt_sha1 == tgt_sha1
   967        if xf.style == "diff":
   968          src_sha1 = self.src.RangeSha1(xf.src_ranges)
   969          assert xf.src_sha1 == src_sha1
   970  
   971    def AssertSequenceGood(self):
   972      # Simulate the sequences of transfers we will output, and check that:
   973      # - we never read a block after writing it, and
   974      # - we write every block we care about exactly once.
   975  
   976      # Start with no blocks having been touched yet.
   977      touched = array.array("B", "\0" * self.tgt.total_blocks)
   978  
   979      # Imagine processing the transfers in order.
   980      for xf in self.transfers:
   981        # Check that the input blocks for this transfer haven't yet been touched.
   982  
   983        x = xf.src_ranges
   984        for _, sr in xf.use_stash:
   985          x = x.subtract(sr)
   986  
   987        for s, e in x:
   988          # Source image could be larger. Don't check the blocks that are in the
   989          # source image only. Since they are not in 'touched', and won't ever
   990          # be touched.
   991          for i in range(s, min(e, self.tgt.total_blocks)):
   992            assert touched[i] == 0
   993  
   994        # Check that the output blocks for this transfer haven't yet
   995        # been touched, and touch all the blocks written by this
   996        # transfer.
   997        for s, e in xf.tgt_ranges:
   998          for i in range(s, e):
   999            assert touched[i] == 0
  1000            touched[i] = 1
  1001  
  1002      # Check that we've written every target block.
  1003      for s, e in self.tgt.care_map:
  1004        for i in range(s, e):
  1005          assert touched[i] == 1
  1006  
  1007    def ImproveVertexSequence(self):
  1008      print("Improving vertex order...")
  1009  
  1010      # At this point our digraph is acyclic; we reversed any edges that
  1011      # were backwards in the heuristically-generated sequence.  The
  1012      # previously-generated order is still acceptable, but we hope to
  1013      # find a better order that needs less memory for stashed data.
  1014      # Now we do a topological sort to generate a new vertex order,
  1015      # using a greedy algorithm to choose which vertex goes next
  1016      # whenever we have a choice.
  1017  
  1018      # Make a copy of the edge set; this copy will get destroyed by the
  1019      # algorithm.
  1020      for xf in self.transfers:
  1021        xf.incoming = xf.goes_after.copy()
  1022        xf.outgoing = xf.goes_before.copy()
  1023  
  1024      L = []   # the new vertex order
  1025  
  1026      # S is the set of sources in the remaining graph; we always choose
  1027      # the one that leaves the least amount of stashed data after it's
  1028      # executed.
  1029      S = [(u.NetStashChange(), u.order, u) for u in self.transfers
  1030           if not u.incoming]
  1031      heapq.heapify(S)
  1032  
  1033      while S:
  1034        _, _, xf = heapq.heappop(S)
  1035        L.append(xf)
  1036        for u in xf.outgoing:
  1037          del u.incoming[xf]
  1038          if not u.incoming:
  1039            heapq.heappush(S, (u.NetStashChange(), u.order, u))
  1040  
  1041      # if this fails then our graph had a cycle.
  1042      assert len(L) == len(self.transfers)
  1043  
  1044      self.transfers = L
  1045      for i, xf in enumerate(L):
  1046        xf.order = i
  1047  
  1048    def RemoveBackwardEdges(self):
  1049      print("Removing backward edges...")
  1050      in_order = 0
  1051      out_of_order = 0
  1052      lost_source = 0
  1053  
  1054      for xf in self.transfers:
  1055        lost = 0
  1056        size = xf.src_ranges.size()
  1057        for u in xf.goes_before:
  1058          # xf should go before u
  1059          if xf.order < u.order:
  1060            # it does, hurray!
  1061            in_order += 1
  1062          else:
  1063            # it doesn't, boo.  trim the blocks that u writes from xf's
  1064            # source, so that xf can go after u.
  1065            out_of_order += 1
  1066            assert xf.src_ranges.overlaps(u.tgt_ranges)
  1067            xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
  1068            xf.src_ranges.extra['trimmed'] = True
  1069  
  1070        if xf.style == "diff" and not xf.src_ranges:
  1071          # nothing left to diff from; treat as new data
  1072          xf.style = "new"
  1073  
  1074        lost = size - xf.src_ranges.size()
  1075        lost_source += lost
  1076  
  1077      print(("  %d/%d dependencies (%.2f%%) were violated; "
  1078             "%d source blocks removed.") %
  1079            (out_of_order, in_order + out_of_order,
  1080             (out_of_order * 100.0 / (in_order + out_of_order))
  1081             if (in_order + out_of_order) else 0.0,
  1082             lost_source))
  1083  
  1084    def ReverseBackwardEdges(self):
  1085      """Reverse unsatisfying edges and compute pairs of stashed blocks.
  1086  
  1087      For each transfer, make sure it properly stashes the blocks it touches and
  1088      will be used by later transfers. It uses pairs of (stash_raw_id, range) to
  1089      record the blocks to be stashed. 'stash_raw_id' is an id that uniquely
  1090      identifies each pair. Note that for the same range (e.g. RangeSet("1-5")),
  1091      it is possible to have multiple pairs with different 'stash_raw_id's. Each
  1092      'stash_raw_id' will be consumed by one transfer. In BBOTA v3+, identical
  1093      blocks will be written to the same stash slot in WriteTransfers().
  1094      """
  1095  
  1096      print("Reversing backward edges...")
  1097      in_order = 0
  1098      out_of_order = 0
  1099      stash_raw_id = 0
  1100      stash_size = 0
  1101  
  1102      for xf in self.transfers:
  1103        for u in xf.goes_before.copy():
  1104          # xf should go before u
  1105          if xf.order < u.order:
  1106            # it does, hurray!
  1107            in_order += 1
  1108          else:
  1109            # it doesn't, boo.  modify u to stash the blocks that it
  1110            # writes that xf wants to read, and then require u to go
  1111            # before xf.
  1112            out_of_order += 1
  1113  
  1114            overlap = xf.src_ranges.intersect(u.tgt_ranges)
  1115            assert overlap
  1116  
  1117            u.stash_before.append((stash_raw_id, overlap))
  1118            xf.use_stash.append((stash_raw_id, overlap))
  1119            stash_raw_id += 1
  1120            stash_size += overlap.size()
  1121  
  1122            # reverse the edge direction; now xf must go after u
  1123            del xf.goes_before[u]
  1124            del u.goes_after[xf]
  1125            xf.goes_after[u] = None    # value doesn't matter
  1126            u.goes_before[xf] = None
  1127  
  1128      print(("  %d/%d dependencies (%.2f%%) were violated; "
  1129             "%d source blocks stashed.") %
  1130            (out_of_order, in_order + out_of_order,
  1131             (out_of_order * 100.0 / (in_order + out_of_order))
  1132             if (in_order + out_of_order) else 0.0,
  1133             stash_size))
  1134  
  1135    def FindVertexSequence(self):
  1136      print("Finding vertex sequence...")
  1137  
  1138      # This is based on "A Fast & Effective Heuristic for the Feedback
  1139      # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
  1140      # it as starting with the digraph G and moving all the vertices to
  1141      # be on a horizontal line in some order, trying to minimize the
  1142      # number of edges that end up pointing to the left.  Left-pointing
  1143      # edges will get removed to turn the digraph into a DAG.  In this
  1144      # case each edge has a weight which is the number of source blocks
  1145      # we'll lose if that edge is removed; we try to minimize the total
  1146      # weight rather than just the number of edges.
  1147  
  1148      # Make a copy of the edge set; this copy will get destroyed by the
  1149      # algorithm.
  1150      for xf in self.transfers:
  1151        xf.incoming = xf.goes_after.copy()
  1152        xf.outgoing = xf.goes_before.copy()
  1153        xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())
  1154  
  1155      # We use an OrderedDict instead of just a set so that the output
  1156      # is repeatable; otherwise it would depend on the hash values of
  1157      # the transfer objects.
  1158      G = OrderedDict()
  1159      for xf in self.transfers:
  1160        G[xf] = None
  1161      s1 = deque()  # the left side of the sequence, built from left to right
  1162      s2 = deque()  # the right side of the sequence, built from right to left
  1163  
  1164      heap = []
  1165      for xf in self.transfers:
  1166        xf.heap_item = HeapItem(xf)
  1167        heap.append(xf.heap_item)
  1168      heapq.heapify(heap)
  1169  
  1170      # Use OrderedDict() instead of set() to preserve the insertion order. Need
  1171      # to use 'sinks[key] = None' to add key into the set. sinks will look like
  1172      # { key1: None, key2: None, ... }.
  1173      sinks = OrderedDict.fromkeys(u for u in G if not u.outgoing)
  1174      sources = OrderedDict.fromkeys(u for u in G if not u.incoming)
  1175  
  1176      def adjust_score(iu, delta):
  1177        iu.score += delta
  1178        iu.heap_item.clear()
  1179        iu.heap_item = HeapItem(iu)
  1180        heapq.heappush(heap, iu.heap_item)
  1181  
  1182      while G:
  1183        # Put all sinks at the end of the sequence.
  1184        while sinks:
  1185          new_sinks = OrderedDict()
  1186          for u in sinks:
  1187            if u not in G:
  1188              continue
  1189            s2.appendleft(u)
  1190            del G[u]
  1191            for iu in u.incoming:
  1192              adjust_score(iu, -iu.outgoing.pop(u))
  1193              if not iu.outgoing:
  1194                new_sinks[iu] = None
  1195          sinks = new_sinks
  1196  
  1197        # Put all the sources at the beginning of the sequence.
  1198        while sources:
  1199          new_sources = OrderedDict()
  1200          for u in sources:
  1201            if u not in G:
  1202              continue
  1203            s1.append(u)
  1204            del G[u]
  1205            for iu in u.outgoing:
  1206              adjust_score(iu, +iu.incoming.pop(u))
  1207              if not iu.incoming:
  1208                new_sources[iu] = None
  1209          sources = new_sources
  1210  
  1211        if not G:
  1212          break
  1213  
  1214        # Find the "best" vertex to put next.  "Best" is the one that
  1215        # maximizes the net difference in source blocks saved we get by
  1216        # pretending it's a source rather than a sink.
  1217  
  1218        while True:
  1219          u = heapq.heappop(heap)
  1220          if u and u.item in G:
  1221            u = u.item
  1222            break
  1223  
  1224        s1.append(u)
  1225        del G[u]
  1226        for iu in u.outgoing:
  1227          adjust_score(iu, +iu.incoming.pop(u))
  1228          if not iu.incoming:
  1229            sources[iu] = None
  1230  
  1231        for iu in u.incoming:
  1232          adjust_score(iu, -iu.outgoing.pop(u))
  1233          if not iu.outgoing:
  1234            sinks[iu] = None
  1235  
  1236      # Now record the sequence in the 'order' field of each transfer,
  1237      # and by rearranging self.transfers to be in the chosen sequence.
  1238  
  1239      new_transfers = []
  1240      for x in itertools.chain(s1, s2):
  1241        x.order = len(new_transfers)
  1242        new_transfers.append(x)
  1243        del x.incoming
  1244        del x.outgoing
  1245  
  1246      self.transfers = new_transfers
  1247  
  1248    def GenerateDigraph(self):
  1249      print("Generating digraph...")
  1250  
  1251      # Each item of source_ranges will be:
  1252      #   - None, if that block is not used as a source,
  1253      #   - an ordered set of transfers.
  1254      source_ranges = []
  1255      for b in self.transfers:
  1256        for s, e in b.src_ranges:
  1257          if e > len(source_ranges):
  1258            source_ranges.extend([None] * (e-len(source_ranges)))
  1259          for i in range(s, e):
  1260            if source_ranges[i] is None:
  1261              source_ranges[i] = OrderedDict.fromkeys([b])
  1262            else:
  1263              source_ranges[i][b] = None
  1264  
  1265      for a in self.transfers:
  1266        intersections = OrderedDict()
  1267        for s, e in a.tgt_ranges:
  1268          for i in range(s, e):
  1269            if i >= len(source_ranges):
  1270              break
  1271            # Add all the Transfers in source_ranges[i] to the (ordered) set.
  1272            if source_ranges[i] is not None:
  1273              for j in source_ranges[i]:
  1274                intersections[j] = None
  1275  
  1276        for b in intersections:
  1277          if a is b:
  1278            continue
  1279  
  1280          # If the blocks written by A are read by B, then B needs to go before A.
  1281          i = a.tgt_ranges.intersect(b.src_ranges)
  1282          if i:
  1283            if b.src_name == "__ZERO":
  1284              # the cost of removing source blocks for the __ZERO domain
  1285              # is (nearly) zero.
  1286              size = 0
  1287            else:
  1288              size = i.size()
  1289            b.goes_before[a] = size
  1290            a.goes_after[b] = size
  1291  
  1292    def FindTransfers(self):
  1293      """Parse the file_map to generate all the transfers."""
  1294  
  1295      def AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
  1296                                               src_ranges, style, by_id):
  1297        """Add one or multiple Transfer()s by splitting large files.
  1298  
  1299        For BBOTA v3, we need to stash source blocks for resumable feature.
  1300        However, with the growth of file size and the shrink of the cache
  1301        partition source blocks are too large to be stashed. If a file occupies
  1302        too many blocks, we split it into smaller pieces by getting multiple
  1303        Transfer()s.
  1304  
  1305        The downside is that after splitting, we may increase the package size
  1306        since the split pieces don't align well. According to our experiments,
  1307        1/8 of the cache size as the per-piece limit appears to be optimal.
  1308        Compared to the fixed 1024-block limit, it reduces the overall package
  1309        size by 30% for volantis, and 20% for angler and bullhead."""
  1310  
  1311        pieces = 0
  1312        while (tgt_ranges.size() > max_blocks_per_transfer and
  1313               src_ranges.size() > max_blocks_per_transfer):
  1314          tgt_split_name = "%s-%d" % (tgt_name, pieces)
  1315          src_split_name = "%s-%d" % (src_name, pieces)
  1316          tgt_first = tgt_ranges.first(max_blocks_per_transfer)
  1317          src_first = src_ranges.first(max_blocks_per_transfer)
  1318  
  1319          Transfer(tgt_split_name, src_split_name, tgt_first, src_first,
  1320                   self.tgt.RangeSha1(tgt_first), self.src.RangeSha1(src_first),
  1321                   style, by_id)
  1322  
  1323          tgt_ranges = tgt_ranges.subtract(tgt_first)
  1324          src_ranges = src_ranges.subtract(src_first)
  1325          pieces += 1
  1326  
  1327        # Handle remaining blocks.
  1328        if tgt_ranges.size() or src_ranges.size():
  1329          # Must be both non-empty.
  1330          assert tgt_ranges.size() and src_ranges.size()
  1331          tgt_split_name = "%s-%d" % (tgt_name, pieces)
  1332          src_split_name = "%s-%d" % (src_name, pieces)
  1333          Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges,
  1334                   self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
  1335                   style, by_id)
  1336  
  1337      def AddSplitTransfers(tgt_name, src_name, tgt_ranges, src_ranges, style,
  1338                            by_id):
  1339        """Find all the zip files and split the others with a fixed chunk size.
  1340  
  1341        This function will construct a list of zip archives, which will later be
  1342        split by imgdiff to reduce the final patch size. For the other files,
  1343        we will plainly split them based on a fixed chunk size with the potential
  1344        patch size penalty.
  1345        """
  1346  
  1347        assert style == "diff"
  1348  
  1349        # Change nothing for small files.
  1350        if (tgt_ranges.size() <= max_blocks_per_transfer and
  1351            src_ranges.size() <= max_blocks_per_transfer):
  1352          Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
  1353                   self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
  1354                   style, by_id)
  1355          return
  1356  
  1357        # Split large APKs with imgdiff, if possible. We're intentionally checking
  1358        # file types one more time (CanUseImgdiff() checks that as well), before
  1359        # calling the costly RangeSha1()s.
  1360        if (self.FileTypeSupportedByImgdiff(tgt_name) and
  1361            self.tgt.RangeSha1(tgt_ranges) != self.src.RangeSha1(src_ranges)):
  1362          if self.CanUseImgdiff(tgt_name, tgt_ranges, src_ranges, True):
  1363            large_apks.append((tgt_name, src_name, tgt_ranges, src_ranges))
  1364            return
  1365  
  1366        AddSplitTransfersWithFixedSizeChunks(tgt_name, src_name, tgt_ranges,
  1367                                             src_ranges, style, by_id)
  1368  
  1369      def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
  1370                      split=False):
  1371        """Wrapper function for adding a Transfer()."""
  1372  
  1373        # We specialize diff transfers only (which covers bsdiff/imgdiff/move);
  1374        # otherwise add the Transfer() as is.
  1375        if style != "diff" or not split:
  1376          Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
  1377                   self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
  1378                   style, by_id)
  1379          return
  1380  
  1381        # Handle .odex files specially to analyze the block-wise difference. If
  1382        # most of the blocks are identical with only few changes (e.g. header),
  1383        # we will patch the changed blocks only. This avoids stashing unchanged
  1384        # blocks while patching. We limit the analysis to files without size
  1385        # changes only. This is to avoid sacrificing the OTA generation cost too
  1386        # much.
  1387        if (tgt_name.split(".")[-1].lower() == 'odex' and
  1388            tgt_ranges.size() == src_ranges.size()):
  1389  
  1390          # 0.5 threshold can be further tuned. The tradeoff is: if only very
  1391          # few blocks remain identical, we lose the opportunity to use imgdiff
  1392          # that may have better compression ratio than bsdiff.
  1393          crop_threshold = 0.5
  1394  
  1395          tgt_skipped = RangeSet()
  1396          src_skipped = RangeSet()
  1397          tgt_size = tgt_ranges.size()
  1398          tgt_changed = 0
  1399          for src_block, tgt_block in zip(src_ranges.next_item(),
  1400                                          tgt_ranges.next_item()):
  1401            src_rs = RangeSet(str(src_block))
  1402            tgt_rs = RangeSet(str(tgt_block))
  1403            if self.src.ReadRangeSet(src_rs) == self.tgt.ReadRangeSet(tgt_rs):
  1404              tgt_skipped = tgt_skipped.union(tgt_rs)
  1405              src_skipped = src_skipped.union(src_rs)
  1406            else:
  1407              tgt_changed += tgt_rs.size()
  1408  
  1409            # Terminate early if no clear sign of benefits.
  1410            if tgt_changed > tgt_size * crop_threshold:
  1411              break
  1412  
  1413          if tgt_changed < tgt_size * crop_threshold:
  1414            assert tgt_changed + tgt_skipped.size() == tgt_size
  1415            print('%10d %10d (%6.2f%%) %s' % (
  1416                tgt_skipped.size(), tgt_size,
  1417                tgt_skipped.size() * 100.0 / tgt_size, tgt_name))
  1418            AddSplitTransfers(
  1419                "%s-skipped" % (tgt_name,),
  1420                "%s-skipped" % (src_name,),
  1421                tgt_skipped, src_skipped, style, by_id)
  1422  
  1423            # Intentionally change the file extension to avoid being imgdiff'd as
  1424            # the files are no longer in their original format.
  1425            tgt_name = "%s-cropped" % (tgt_name,)
  1426            src_name = "%s-cropped" % (src_name,)
  1427            tgt_ranges = tgt_ranges.subtract(tgt_skipped)
  1428            src_ranges = src_ranges.subtract(src_skipped)
  1429  
  1430            # Possibly having no changed blocks.
  1431            if not tgt_ranges:
  1432              return
  1433  
  1434        # Add the transfer(s).
  1435        AddSplitTransfers(
  1436            tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
  1437  
  1438      def ParseAndValidateSplitInfo(patch_size, tgt_ranges, src_ranges,
  1439                                    split_info):
  1440        """Parse the split_info and return a list of info tuples.
  1441  
  1442        Args:
  1443          patch_size: total size of the patch file.
  1444          tgt_ranges: Ranges of the target file within the original image.
  1445          src_ranges: Ranges of the source file within the original image.
  1446          split_info format:
  1447            imgdiff version#
  1448            count of pieces
  1449            <patch_size_1> <tgt_size_1> <src_ranges_1>
  1450            ...
  1451            <patch_size_n> <tgt_size_n> <src_ranges_n>
  1452  
  1453        Returns:
  1454          [patch_start, patch_len, split_tgt_ranges, split_src_ranges]
  1455        """
  1456  
  1457        version = int(split_info[0])
  1458        assert version == 2
  1459        count = int(split_info[1])
  1460        assert len(split_info) - 2 == count
  1461  
  1462        split_info_list = []
  1463        patch_start = 0
  1464        tgt_remain = copy.deepcopy(tgt_ranges)
  1465        # each line has the format <patch_size>, <tgt_size>, <src_ranges>
  1466        for line in split_info[2:]:
  1467          info = line.split()
  1468          assert len(info) == 3
  1469          patch_length = int(info[0])
  1470  
  1471          split_tgt_size = int(info[1])
  1472          assert split_tgt_size % 4096 == 0
  1473          assert split_tgt_size / 4096 <= tgt_remain.size()
  1474          split_tgt_ranges = tgt_remain.first(split_tgt_size / 4096)
  1475          tgt_remain = tgt_remain.subtract(split_tgt_ranges)
  1476  
  1477          # Find the split_src_ranges within the image file from its relative
  1478          # position in file.
  1479          split_src_indices = RangeSet.parse_raw(info[2])
  1480          split_src_ranges = RangeSet()
  1481          for r in split_src_indices:
  1482            curr_range = src_ranges.first(r[1]).subtract(src_ranges.first(r[0]))
  1483            assert not split_src_ranges.overlaps(curr_range)
  1484            split_src_ranges = split_src_ranges.union(curr_range)
  1485  
  1486          split_info_list.append((patch_start, patch_length,
  1487                                  split_tgt_ranges, split_src_ranges))
  1488          patch_start += patch_length
  1489  
  1490        # Check that the sizes of all the split pieces add up to the final file
  1491        # size for patch and target.
  1492        assert tgt_remain.size() == 0
  1493        assert patch_start == patch_size
  1494        return split_info_list
  1495  
  1496      def SplitLargeApks():
  1497        """Split the large apks files.
  1498  
  1499        Example: Chrome.apk will be split into
  1500          src-0: Chrome.apk-0, tgt-0: Chrome.apk-0
  1501          src-1: Chrome.apk-1, tgt-1: Chrome.apk-1
  1502          ...
  1503  
  1504        After the split, the target pieces are continuous and block aligned; and
  1505        the source pieces are mutually exclusive. During the split, we also
  1506        generate and save the image patch between src-X & tgt-X. This patch will
  1507        be valid because the block ranges of src-X & tgt-X will always stay the
  1508        same afterwards; but there's a chance we don't use the patch if we
  1509        convert the "diff" command into "new" or "move" later.
  1510        """
  1511  
  1512        while True:
  1513          with transfer_lock:
  1514            if not large_apks:
  1515              return
  1516            tgt_name, src_name, tgt_ranges, src_ranges = large_apks.pop(0)
  1517  
  1518          src_file = common.MakeTempFile(prefix="src-")
  1519          tgt_file = common.MakeTempFile(prefix="tgt-")
  1520          with open(src_file, "wb") as src_fd:
  1521            self.src.WriteRangeDataToFd(src_ranges, src_fd)
  1522          with open(tgt_file, "wb") as tgt_fd:
  1523            self.tgt.WriteRangeDataToFd(tgt_ranges, tgt_fd)
  1524  
  1525          patch_file = common.MakeTempFile(prefix="patch-")
  1526          patch_info_file = common.MakeTempFile(prefix="split_info-")
  1527          cmd = ["imgdiff", "-z",
  1528                 "--block-limit={}".format(max_blocks_per_transfer),
  1529                 "--split-info=" + patch_info_file,
  1530                 src_file, tgt_file, patch_file]
  1531          p = common.Run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
  1532          imgdiff_output, _ = p.communicate()
  1533          assert p.returncode == 0, \
  1534              "Failed to create imgdiff patch between {} and {}:\n{}".format(
  1535                  src_name, tgt_name, imgdiff_output)
  1536  
  1537          with open(patch_info_file) as patch_info:
  1538            lines = patch_info.readlines()
  1539  
  1540          patch_size_total = os.path.getsize(patch_file)
  1541          split_info_list = ParseAndValidateSplitInfo(patch_size_total,
  1542                                                      tgt_ranges, src_ranges,
  1543                                                      lines)
  1544          for index, (patch_start, patch_length, split_tgt_ranges,
  1545                      split_src_ranges) in enumerate(split_info_list):
  1546            with open(patch_file) as f:
  1547              f.seek(patch_start)
  1548              patch_content = f.read(patch_length)
  1549  
  1550            split_src_name = "{}-{}".format(src_name, index)
  1551            split_tgt_name = "{}-{}".format(tgt_name, index)
  1552            split_large_apks.append((split_tgt_name,
  1553                                     split_src_name,
  1554                                     split_tgt_ranges,
  1555                                     split_src_ranges,
  1556                                     patch_content))
  1557  
  1558      print("Finding transfers...")
  1559  
  1560      large_apks = []
  1561      split_large_apks = []
  1562      cache_size = common.OPTIONS.cache_size
  1563      split_threshold = 0.125
  1564      max_blocks_per_transfer = int(cache_size * split_threshold /
  1565                                    self.tgt.blocksize)
  1566      empty = RangeSet()
  1567      for tgt_fn, tgt_ranges in sorted(self.tgt.file_map.items()):
  1568        if tgt_fn == "__ZERO":
  1569          # the special "__ZERO" domain is all the blocks not contained
  1570          # in any file and that are filled with zeros.  We have a
  1571          # special transfer style for zero blocks.
  1572          src_ranges = self.src.file_map.get("__ZERO", empty)
  1573          AddTransfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
  1574                      "zero", self.transfers)
  1575          continue
  1576  
  1577        elif tgt_fn == "__COPY":
  1578          # "__COPY" domain includes all the blocks not contained in any
  1579          # file and that need to be copied unconditionally to the target.
  1580          AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
  1581          continue
  1582  
  1583        elif tgt_fn in self.src.file_map:
  1584          # Look for an exact pathname match in the source.
  1585          AddTransfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
  1586                      "diff", self.transfers, True)
  1587          continue
  1588  
  1589        b = os.path.basename(tgt_fn)
  1590        if b in self.src_basenames:
  1591          # Look for an exact basename match in the source.
  1592          src_fn = self.src_basenames[b]
  1593          AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
  1594                      "diff", self.transfers, True)
  1595          continue
  1596  
  1597        b = re.sub("[0-9]+", "#", b)
  1598        if b in self.src_numpatterns:
  1599          # Look for a 'number pattern' match (a basename match after
  1600          # all runs of digits are replaced by "#").  (This is useful
  1601          # for .so files that contain version numbers in the filename
  1602          # that get bumped.)
  1603          src_fn = self.src_numpatterns[b]
  1604          AddTransfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
  1605                      "diff", self.transfers, True)
  1606          continue
  1607  
  1608        AddTransfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
  1609  
  1610      transfer_lock = threading.Lock()
  1611      threads = [threading.Thread(target=SplitLargeApks)
  1612                 for _ in range(self.threads)]
  1613      for th in threads:
  1614        th.start()
  1615      while threads:
  1616        threads.pop().join()
  1617  
  1618      # Sort the split transfers for large apks to generate a determinate package.
  1619      split_large_apks.sort()
  1620      for (tgt_name, src_name, tgt_ranges, src_ranges,
  1621           patch) in split_large_apks:
  1622        transfer_split = Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
  1623                                  self.tgt.RangeSha1(tgt_ranges),
  1624                                  self.src.RangeSha1(src_ranges),
  1625                                  "diff", self.transfers)
  1626        transfer_split.patch = patch
  1627  
  1628    def AbbreviateSourceNames(self):
  1629      for k in self.src.file_map.keys():
  1630        b = os.path.basename(k)
  1631        self.src_basenames[b] = k
  1632        b = re.sub("[0-9]+", "#", b)
  1633        self.src_numpatterns[b] = k
  1634  
  1635    @staticmethod
  1636    def AssertPartition(total, seq):
  1637      """Assert that all the RangeSets in 'seq' form a partition of the
  1638      'total' RangeSet (ie, they are nonintersecting and their union
  1639      equals 'total')."""
  1640  
  1641      so_far = RangeSet()
  1642      for i in seq:
  1643        assert not so_far.overlaps(i)
  1644        so_far = so_far.union(i)
  1645      assert so_far == total