releasetools: Reduce memory footprint for BBOTA generation.

The major issue with the existing implementation is unnecessarily
holding too much data in memory, such as HashBlocks() which first reads
in *all* the data to a list before hashing. We can leverage generator
functions to stream such operations.

This CL makes the following changes to reduce the peak memory use.
 - Adding RangeSha1() and WriteRangeDataToFd() to Image classes. These
   functions perform the operations on-the-fly.
 - Caching the computed SHA-1 values for a Transfer instance.

As a result, this CL reduces the peak memory use by ~80% (e.g. reducing
from 5.85GB to 1.16GB for the same incremental, as shown by "Maximum
resident set size" from `/usr/bin/time -v`). It also effectively
improves the (package generation) performance by ~30%.

Bug: 35768998
Bug: 32312123
Test: Generating the same incremental w/ and w/o the CL give identical
      output packages.
Change-Id: Ia5c6314b41da73dd6fe1dbe2ca81bbd89b517cec
diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index c204c90..6c842dc 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -24,8 +24,8 @@
 import os.path
 import re
 import subprocess
+import sys
 import threading
-import tempfile
 
 from collections import deque, OrderedDict
 from hashlib import sha1
@@ -35,69 +35,67 @@
 __all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
 
 
-def compute_patch(src, tgt, imgdiff=False):
-  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
-  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
-  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
-  os.close(patchfd)
+def compute_patch(srcfile, tgtfile, imgdiff=False):
+  patchfile = common.MakeTempFile(prefix="patch-")
 
-  try:
-    with os.fdopen(srcfd, "wb") as f_src:
-      for p in src:
-        f_src.write(p)
+  if imgdiff:
+    p = subprocess.call(
+        ["imgdiff", "-z", srcfile, tgtfile, patchfile],
+        stdout=open(os.devnull, 'w'),
+        stderr=subprocess.STDOUT)
+  else:
+    p = subprocess.call(
+        ["bsdiff", srcfile, tgtfile, patchfile],
+        stdout=open(os.devnull, 'w'),
+        stderr=subprocess.STDOUT)
 
-    with os.fdopen(tgtfd, "wb") as f_tgt:
-      for p in tgt:
-        f_tgt.write(p)
-    try:
-      os.unlink(patchfile)
-    except OSError:
-      pass
-    if imgdiff:
-      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
-                          stdout=open("/dev/null", "a"),
-                          stderr=subprocess.STDOUT)
-    else:
-      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])
+  if p:
+    raise ValueError("diff failed: " + str(p))
 
-    if p:
-      raise ValueError("diff failed: " + str(p))
-
-    with open(patchfile, "rb") as f:
-      return f.read()
-  finally:
-    try:
-      os.unlink(srcfile)
-      os.unlink(tgtfile)
-      os.unlink(patchfile)
-    except OSError:
-      pass
+  with open(patchfile, "rb") as f:
+    return f.read()
 
 
 class Image(object):
+  def RangeSha1(self, ranges):
+    raise NotImplementedError
+
   def ReadRangeSet(self, ranges):
     raise NotImplementedError
 
   def TotalSha1(self, include_clobbered_blocks=False):
     raise NotImplementedError
 
+  def WriteRangeDataToFd(self, ranges, fd):
+    raise NotImplementedError
+
 
 class EmptyImage(Image):
   """A zero-length image."""
-  blocksize = 4096
-  care_map = RangeSet()
-  clobbered_blocks = RangeSet()
-  extended = RangeSet()
-  total_blocks = 0
-  file_map = {}
+
+  def __init__(self):
+    self.blocksize = 4096
+    self.care_map = RangeSet()
+    self.clobbered_blocks = RangeSet()
+    self.extended = RangeSet()
+    self.total_blocks = 0
+    self.file_map = {}
+
+  def RangeSha1(self, ranges):
+    return sha1().hexdigest()
+
   def ReadRangeSet(self, ranges):
     return ()
+
   def TotalSha1(self, include_clobbered_blocks=False):
     # EmptyImage always carries empty clobbered_blocks, so
     # include_clobbered_blocks can be ignored.
     assert self.clobbered_blocks.size() == 0
     return sha1().hexdigest()
 
+  def WriteRangeDataToFd(self, ranges, fd):
+    raise ValueError("Can't write data from EmptyImage to file")
+
 
 class DataImage(Image):
   """An image wrapped around a single string of data."""
@@ -160,23 +158,39 @@
     if clobbered_blocks:
       self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)
 
+  def _GetRangeData(self, ranges):
+    for s, e in ranges:
+      yield self.data[s*self.blocksize:e*self.blocksize]
+
+  def RangeSha1(self, ranges):
+    h = sha1()
+    for data in self._GetRangeData(ranges):
+      h.update(data)
+    return h.hexdigest()
+
   def ReadRangeSet(self, ranges):
-    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]
+    return [self._GetRangeData(ranges)]
 
   def TotalSha1(self, include_clobbered_blocks=False):
     if not include_clobbered_blocks:
-      ranges = self.care_map.subtract(self.clobbered_blocks)
-      return sha1(self.ReadRangeSet(ranges)).hexdigest()
+      return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
     else:
       return sha1(self.data).hexdigest()
 
+  def WriteRangeDataToFd(self, ranges, fd):
+    for data in self._GetRangeData(ranges):
+      fd.write(data)
+
 
 class Transfer(object):
-  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
+  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
+               src_sha1, style, by_id):
     self.tgt_name = tgt_name
     self.src_name = src_name
     self.tgt_ranges = tgt_ranges
     self.src_ranges = src_ranges
+    self.tgt_sha1 = tgt_sha1
+    self.src_sha1 = src_sha1
     self.style = style
     self.intact = (getattr(tgt_ranges, "monotonic", False) and
                    getattr(src_ranges, "monotonic", False))
@@ -251,6 +265,9 @@
 #      Implementations are free to break up the data into list/tuple
 #      elements in any way that is convenient.
 #
+#    RangeSha1(): a function that returns (as a hex string) the SHA-1
+#      hash of all the data in the specified range.
+#
 #    TotalSha1(): a function that returns (as a hex string) the SHA-1
 #      hash of all the data in the image (ie, all the blocks in the
 #      care_map minus clobbered_blocks, or including the clobbered
@@ -332,15 +349,6 @@
     self.ComputePatches(prefix)
     self.WriteTransfers(prefix)
 
-  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
-    data = source.ReadRangeSet(ranges)
-    ctx = sha1()
-
-    for p in data:
-      ctx.update(p)
-
-    return ctx.hexdigest()
-
   def WriteTransfers(self, prefix):
     def WriteSplitTransfers(out, style, target_blocks):
       """Limit the size of operand in command 'new' and 'zero' to 1024 blocks.
@@ -397,7 +405,7 @@
           stashed_blocks += sr.size()
           out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
         else:
-          sh = self.HashBlocks(self.src, sr)
+          sh = self.src.RangeSha1(sr)
           if sh in stashes:
             stashes[sh] += 1
           else:
@@ -429,7 +437,7 @@
         mapped_stashes = []
         for stash_raw_id, sr in xf.use_stash:
           unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
-          sh = self.HashBlocks(self.src, sr)
+          sh = self.src.RangeSha1(sr)
           sr = xf.src_ranges.map_within(sr)
           mapped_stashes.append(sr)
           if self.version == 2:
@@ -515,7 +523,7 @@
 
             out.append("%s %s %s %s\n" % (
                 xf.style,
-                self.HashBlocks(self.tgt, xf.tgt_ranges),
+                xf.tgt_sha1,
                 xf.tgt_ranges.to_string_raw(), src_str))
           total += tgt_size
       elif xf.style in ("bsdiff", "imgdiff"):
@@ -542,8 +550,8 @@
           out.append("%s %d %d %s %s %s %s\n" % (
               xf.style,
               xf.patch_start, xf.patch_len,
-              self.HashBlocks(self.src, xf.src_ranges),
-              self.HashBlocks(self.tgt, xf.tgt_ranges),
+              xf.src_sha1,
+              xf.tgt_sha1,
               xf.tgt_ranges.to_string_raw(), src_str))
         total += tgt_size
       elif xf.style == "zero":
@@ -574,8 +582,7 @@
                    stash_threshold)
 
     if self.version >= 3:
-      self.touched_src_sha1 = self.HashBlocks(
-          self.src, self.touched_src_ranges)
+      self.touched_src_sha1 = self.src.RangeSha1(self.touched_src_ranges)
 
     # Zero out extended blocks as a workaround for bug 20881595.
     if self.tgt.extended:
@@ -674,7 +681,7 @@
         if self.version == 2:
           stashed_blocks_after += sr.size()
         else:
-          sh = self.HashBlocks(self.src, sr)
+          sh = self.src.RangeSha1(sr)
           if sh not in stashes:
             stashed_blocks_after += sr.size()
 
@@ -731,7 +738,7 @@
           stashed_blocks -= sr.size()
           heapq.heappush(free_stash_ids, sid)
         else:
-          sh = self.HashBlocks(self.src, sr)
+          sh = self.src.RangeSha1(sr)
           assert sh in stashes
           stashes[sh] -= 1
           if stashes[sh] == 0:
@@ -745,10 +752,10 @@
 
   def ComputePatches(self, prefix):
     print("Reticulating splines...")
-    diff_q = []
+    diff_queue = []
     patch_num = 0
     with open(prefix + ".new.dat", "wb") as new_f:
-      for xf in self.transfers:
+      for index, xf in enumerate(self.transfers):
         if xf.style == "zero":
           tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
           print("%10d %10d (%6.2f%%) %7s %s %s" % (
@@ -756,17 +763,13 @@
               str(xf.tgt_ranges)))
 
         elif xf.style == "new":
-          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
-            new_f.write(piece)
+          self.tgt.WriteRangeDataToFd(xf.tgt_ranges, new_f)
           tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
           print("%10d %10d (%6.2f%%) %7s %s %s" % (
               tgt_size, tgt_size, 100.0, xf.style,
               xf.tgt_name, str(xf.tgt_ranges)))
 
         elif xf.style == "diff":
-          src = self.src.ReadRangeSet(xf.src_ranges)
-          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)
-
           # We can't compare src and tgt directly because they may have
           # the same content but be broken up into blocks differently, eg:
           #
@@ -775,20 +778,11 @@
           # We want those to compare equal, ideally without having to
           # actually concatenate the strings (these may be tens of
           # megabytes).
-
-          src_sha1 = sha1()
-          for p in src:
-            src_sha1.update(p)
-          tgt_sha1 = sha1()
-          tgt_size = 0
-          for p in tgt:
-            tgt_sha1.update(p)
-            tgt_size += len(p)
-
-          if src_sha1.digest() == tgt_sha1.digest():
+          if xf.src_sha1 == xf.tgt_sha1:
             # These are identical; we don't need to generate a patch,
             # just issue copy commands on the device.
             xf.style = "move"
+            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
             if xf.src_ranges != xf.tgt_ranges:
               print("%10d %10d (%6.2f%%) %7s %s %s (from %s)" % (
                   tgt_size, tgt_size, 100.0, xf.style,
@@ -815,38 +809,64 @@
                        xf.tgt_name.split(".")[-1].lower()
                        in ("apk", "jar", "zip"))
             xf.style = "imgdiff" if imgdiff else "bsdiff"
-            diff_q.append((tgt_size, src, tgt, xf, patch_num))
+            diff_queue.append((index, imgdiff, patch_num))
             patch_num += 1
 
         else:
           assert False, "unknown style " + xf.style
 
-    if diff_q:
+    if diff_queue:
       if self.threads > 1:
         print("Computing patches (using %d threads)..." % (self.threads,))
       else:
         print("Computing patches...")
-      diff_q.sort()
 
-      patches = [None] * patch_num
+      diff_total = len(diff_queue)
+      patches = [None] * diff_total
 
-      # TODO: Rewrite with multiprocessing.ThreadPool?
+      # Using multiprocessing doesn't give additional benefits, due to the
+      # pattern of the code. The diffing work is done by subprocess.call, which
+      # already runs in a separate process (not affected much by the GIL -
+      # Global Interpreter Lock). Using multiprocess also requires either a)
+      # writing the diff input files in the main process before forking, or b)
+      # reopening the image file (SparseImage) in the worker processes. Doing
+      # neither of them further improves the performance.
       lock = threading.Lock()
       def diff_worker():
         while True:
           with lock:
-            if not diff_q:
+            if not diff_queue:
               return
-            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
-          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
-          size = len(patch)
+            xf_index, imgdiff, patch_index = diff_queue.pop()
+
+          xf = self.transfers[xf_index]
+          src_ranges = xf.src_ranges
+          tgt_ranges = xf.tgt_ranges
+
+          # Needs lock since WriteRangeDataToFd() is stateful (calling seek).
           with lock:
-            patches[patchnum] = (patch, xf)
-            print("%10d %10d (%6.2f%%) %7s %s %s %s" % (
-                size, tgt_size, size * 100.0 / tgt_size, xf.style,
-                xf.tgt_name if xf.tgt_name == xf.src_name else (
-                    xf.tgt_name + " (from " + xf.src_name + ")"),
-                str(xf.tgt_ranges), str(xf.src_ranges)))
+            src_file = common.MakeTempFile(prefix="src-")
+            with open(src_file, "wb") as fd:
+              self.src.WriteRangeDataToFd(src_ranges, fd)
+
+            tgt_file = common.MakeTempFile(prefix="tgt-")
+            with open(tgt_file, "wb") as fd:
+              self.tgt.WriteRangeDataToFd(tgt_ranges, fd)
+
+          try:
+            patch = compute_patch(src_file, tgt_file, imgdiff)
+          except ValueError as e:
+            raise ValueError(
+                "Failed to generate diff for %s: src=%s, tgt=%s: %s" % (
+                    xf.tgt_name, xf.src_ranges, xf.tgt_ranges, e.message))
+
+          with lock:
+            patches[patch_index] = (xf_index, patch)
+            if sys.stdout.isatty():
+              progress = len(patches) * 100 / diff_total
+              # '\033[K' is to clear to EOL.
+              print(' [%d%%] %s\033[K' % (progress, xf.tgt_name), end='\r')
+              sys.stdout.flush()
 
       threads = [threading.Thread(target=diff_worker)
                  for _ in range(self.threads)]
@@ -854,16 +874,29 @@
         th.start()
       while threads:
         threads.pop().join()
+
+      if sys.stdout.isatty():
+        print('\n')
     else:
       patches = []
 
-    p = 0
-    with open(prefix + ".patch.dat", "wb") as patch_f:
-      for patch, xf in patches:
-        xf.patch_start = p
+    offset = 0
+    with open(prefix + ".patch.dat", "wb") as patch_fd:
+      for index, patch in patches:
+        xf = self.transfers[index]
         xf.patch_len = len(patch)
-        patch_f.write(patch)
-        p += len(patch)
+        xf.patch_start = offset
+        offset += xf.patch_len
+        patch_fd.write(patch)
+
+        if common.OPTIONS.verbose:
+          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
+          print("%10d %10d (%6.2f%%) %7s %s %s %s" % (
+                xf.patch_len, tgt_size, xf.patch_len * 100.0 / tgt_size,
+                xf.style,
+                xf.tgt_name if xf.tgt_name == xf.src_name else (
+                    xf.tgt_name + " (from " + xf.src_name + ")"),
+                xf.tgt_ranges, xf.src_ranges))
 
   def AssertSequenceGood(self):
     # Simulate the sequences of transfers we will output, and check that:
@@ -1211,7 +1244,9 @@
       # Change nothing for small files.
       if (tgt_ranges.size() <= max_blocks_per_transfer and
           src_ranges.size() <= max_blocks_per_transfer):
-        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
+        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
+                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
+                 style, by_id)
         return
 
       while (tgt_ranges.size() > max_blocks_per_transfer and
@@ -1221,8 +1256,9 @@
         tgt_first = tgt_ranges.first(max_blocks_per_transfer)
         src_first = src_ranges.first(max_blocks_per_transfer)
 
-        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
-                 by_id)
+        Transfer(tgt_split_name, src_split_name, tgt_first, src_first,
+                 self.tgt.RangeSha1(tgt_first), self.src.RangeSha1(src_first),
+                 style, by_id)
 
         tgt_ranges = tgt_ranges.subtract(tgt_first)
         src_ranges = src_ranges.subtract(src_first)
@@ -1234,8 +1270,9 @@
         assert tgt_ranges.size() and src_ranges.size()
         tgt_split_name = "%s-%d" % (tgt_name, pieces)
         src_split_name = "%s-%d" % (src_name, pieces)
-        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
-                 by_id)
+        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges,
+                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
+                 style, by_id)
 
     def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                     split=False):
@@ -1244,7 +1281,9 @@
       # We specialize diff transfers only (which covers bsdiff/imgdiff/move);
       # otherwise add the Transfer() as is.
       if style != "diff" or not split:
-        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
+        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
+                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
+                 style, by_id)
         return
 
       # Handle .odex files specially to analyze the block-wise difference. If
diff --git a/tools/releasetools/sparse_img.py b/tools/releasetools/sparse_img.py
index 4ba7560..7eb60d9 100644
--- a/tools/releasetools/sparse_img.py
+++ b/tools/releasetools/sparse_img.py
@@ -144,6 +144,12 @@
     f.seek(16, os.SEEK_SET)
     f.write(struct.pack("<2I", self.total_blocks, self.total_chunks))
 
+  def RangeSha1(self, ranges):
+    h = sha1()
+    for data in self._GetRangeData(ranges):
+      h.update(data)
+    return h.hexdigest()
+
   def ReadRangeSet(self, ranges):
     return [d for d in self._GetRangeData(ranges)]
 
@@ -155,10 +161,11 @@
     ranges = self.care_map
     if not include_clobbered_blocks:
       ranges = ranges.subtract(self.clobbered_blocks)
-    h = sha1()
-    for d in self._GetRangeData(ranges):
-      h.update(d)
-    return h.hexdigest()
+    return self.RangeSha1(ranges)
+
+  def WriteRangeDataToFd(self, ranges, fd):
+    for data in self._GetRangeData(ranges):
+      fd.write(data)
 
   def _GetRangeData(self, ranges):
     """Generator that produces all the image data in 'ranges'.  The