refactor BlockDifference into common

Move BlockDifference into common and make its script generation code
more complete, so that it can be used by releasetools.py to do diffs on
baseband images.

Bug: 16984795
Change-Id: Iba9afc1c7755458ce47468b5170672612b2cb4b3
diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index 7b01f71..6f94eda 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -14,6 +14,8 @@
 
 from rangelib import *
 
+__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
+
 def compute_patch(src, tgt, imgdiff=False):
   srcfd, srcfile = tempfile.mkstemp(prefix="src-")
   tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
@@ -60,6 +62,59 @@
   file_map = {}
   def ReadRangeSet(self, ranges):
     return ()
+  def TotalSha1(self):
+    return sha1().hexdigest()
+
+
+class DataImage(object):
+  """An image wrapped around a single string of data."""
+
+  def __init__(self, data, trim=False, pad=False):
+    self.data = data
+    self.blocksize = 4096
+
+    assert not (trim and pad)
+
+    partial = len(self.data) % self.blocksize
+    if partial > 0:
+      if trim:
+        self.data = self.data[:-partial]
+      elif pad:
+        self.data += '\0' * (self.blocksize - partial)
+      else:
+        raise ValueError(("data for DataImage must be multiple of %d bytes "
+                          "unless trim or pad is specified") %
+                         (self.blocksize,))
+
+    assert len(self.data) % self.blocksize == 0
+
+    self.total_blocks = len(self.data) / self.blocksize
+    self.care_map = RangeSet(data=(0, self.total_blocks))
+
+    zero_blocks = []
+    nonzero_blocks = []
+    reference = '\0' * self.blocksize
+
+    for i in range(self.total_blocks):
+      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
+      if d == reference:
+        zero_blocks.append(i)
+        zero_blocks.append(i+1)
+      else:
+        nonzero_blocks.append(i)
+        nonzero_blocks.append(i+1)
+
+    self.file_map = {"__ZERO": RangeSet(zero_blocks),
+                     "__NONZERO": RangeSet(nonzero_blocks)}
+
+  def ReadRangeSet(self, ranges):
+    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]
+
+  def TotalSha1(self):
+    if not hasattr(self, "sha1"):
+      self.sha1 = sha1(self.data).hexdigest()
+    return self.sha1
+
 
 class Transfer(object):
   def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
@@ -104,6 +159,10 @@
 #      Implementations are free to break up the data into list/tuple
 #      elements in any way that is convenient.
 #
+#    TotalSha1(): a function that returns (as a hex string) the SHA-1
+#      hash of all the data in the image (i.e., all the blocks in the
+#      care_map).
+#
 # When creating a BlockImageDiff, the src image may be None, in which
 # case the list of transfers produced will never read from the
 # original image.
@@ -478,7 +537,12 @@
         # If the blocks written by A are read by B, then B needs to go before A.
         i = a.tgt_ranges.intersect(b.src_ranges)
         if i:
-          size = i.size()
+          if b.src_name == "__ZERO":
+            # the cost of removing source blocks for the __ZERO domain
+            # is (nearly) zero.
+            size = 0
+          else:
+            size = i.size()
           b.goes_before[a] = size
           a.goes_after[b] = size
 
@@ -491,7 +555,8 @@
         # in any file and that are filled with zeros.  We have a
         # special transfer style for zero blocks.
         src_ranges = self.src.file_map.get("__ZERO", empty)
-        Transfer(tgt_fn, None, tgt_ranges, src_ranges, "zero", self.transfers)
+        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
+                 "zero", self.transfers)
         continue
 
       elif tgt_fn in self.src.file_map:
diff --git a/tools/releasetools/common.py b/tools/releasetools/common.py
index 714f71b..83cda63 100644
--- a/tools/releasetools/common.py
+++ b/tools/releasetools/common.py
@@ -29,6 +29,8 @@
 import time
 import zipfile
 
+import blockimgdiff
+
 try:
   from hashlib import sha1 as sha1
 except ImportError:
@@ -1010,6 +1012,60 @@
     threads.pop().join()
 
 
+class BlockDifference:
+  def __init__(self, partition, tgt, src=None):
+    self.tgt = tgt
+    self.src = src
+    self.partition = partition
+
+    b = blockimgdiff.BlockImageDiff(tgt, src, threads=OPTIONS.worker_threads)
+    tmpdir = tempfile.mkdtemp()
+    OPTIONS.tempfiles.append(tmpdir)
+    self.path = os.path.join(tmpdir, partition)
+    b.Compute(self.path)
+
+    _, self.device = GetTypeAndDevice("/" + partition, OPTIONS.info_dict)
+
+  def WriteScript(self, script, output_zip, progress=None):
+    if not self.src:
+      # write the output unconditionally
+      if progress: script.ShowProgress(progress, 0)
+      self._WriteUpdate(script, output_zip)
+
+    else:
+      script.AppendExtra('if range_sha1("%s", "%s") == "%s" then' %
+                         (self.device, self.src.care_map.to_string_raw(),
+                          self.src.TotalSha1()))
+      script.Print("Patching %s image..." % (self.partition,))
+      if progress: script.ShowProgress(progress, 0)
+      self._WriteUpdate(script, output_zip)
+      script.AppendExtra(('else\n'
+                          '  (range_sha1("%s", "%s") == "%s") ||\n'
+                          '  abort("%s partition has unexpected contents");\n'
+                          'endif;') %
+                         (self.device, self.tgt.care_map.to_string_raw(),
+                          self.tgt.TotalSha1(), self.partition))
+
+  def _WriteUpdate(self, script, output_zip):
+    partition = self.partition
+    with open(self.path + ".transfer.list", "rb") as f:
+      ZipWriteStr(output_zip, partition + ".transfer.list", f.read())
+    with open(self.path + ".new.dat", "rb") as f:
+      ZipWriteStr(output_zip, partition + ".new.dat", f.read())
+    with open(self.path + ".patch.dat", "rb") as f:
+      ZipWriteStr(output_zip, partition + ".patch.dat", f.read(),
+                         compression=zipfile.ZIP_STORED)
+
+    call = (('block_image_update("%s", '
+             'package_extract_file("%s.transfer.list"), '
+             '"%s.new.dat", "%s.patch.dat");\n') %
+            (self.device, partition, partition, partition))
+    script.AppendExtra(script._WordWrap(call))
+
+
+DataImage = blockimgdiff.DataImage
+
+
 # map recovery.fstab's fs_types to mount/format "partition types"
 PARTITION_TYPES = { "yaffs2": "MTD", "mtd": "MTD",
                     "ext4": "EMMC", "emmc": "EMMC",
diff --git a/tools/releasetools/ota_from_target_files b/tools/releasetools/ota_from_target_files
index bcc3210..8b7342b 100755
--- a/tools/releasetools/ota_from_target_files
+++ b/tools/releasetools/ota_from_target_files
@@ -455,35 +455,6 @@
   return sparse_img.SparseImage(path, mappath)
 
 
-class BlockDifference:
-  def __init__(self, partition, tgt, src=None):
-    self.partition = partition
-
-    b = blockimgdiff.BlockImageDiff(tgt, src, threads=OPTIONS.worker_threads)
-    tmpdir = tempfile.mkdtemp()
-    OPTIONS.tempfiles.append(tmpdir)
-    self.path = os.path.join(tmpdir, partition)
-    b.Compute(self.path)
-
-    _, self.device = common.GetTypeAndDevice("/" + partition, OPTIONS.info_dict)
-
-  def WriteScript(self, script, output_zip):
-    partition = self.partition
-    with open(self.path + ".transfer.list", "rb") as f:
-      common.ZipWriteStr(output_zip, partition + ".transfer.list", f.read())
-    with open(self.path + ".new.dat", "rb") as f:
-      common.ZipWriteStr(output_zip, partition + ".new.dat", f.read())
-    with open(self.path + ".patch.dat", "rb") as f:
-      common.ZipWriteStr(output_zip, partition + ".patch.dat", f.read(),
-                         compression=zipfile.ZIP_STORED)
-
-    call = (('block_image_update("%s", '
-             'package_extract_file("%s.transfer.list"), '
-             '"%s.new.dat", "%s.patch.dat");\n') %
-            (self.device, partition, partition, partition))
-    script.AppendExtra(script._WordWrap(call))
-
-
 def WriteFullOTAPackage(input_zip, output_zip):
   # TODO: how to determine this?  We don't know what version it will
   # be installed on top of.  For now, we expect the API just won't
@@ -586,7 +557,7 @@
     # writes incrementals to do it.
     system_tgt = GetImage("system", OPTIONS.input_tmp, OPTIONS.info_dict)
     system_tgt.ResetFileMap()
-    system_diff = BlockDifference("system", system_tgt, src=None)
+    system_diff = common.BlockDifference("system", system_tgt, src=None)
     system_diff.WriteScript(script, output_zip)
   else:
     script.FormatPartition("/system")
@@ -619,7 +590,7 @@
     if block_based:
       vendor_tgt = GetImage("vendor", OPTIONS.input_tmp, OPTIONS.info_dict)
       vendor_tgt.ResetFileMap()
-      vendor_diff = BlockDifference("vendor", vendor_tgt)
+      vendor_diff = common.BlockDifference("vendor", vendor_tgt)
       vendor_diff.WriteScript(script, output_zip)
     else:
       script.FormatPartition("/vendor")
@@ -760,14 +731,14 @@
 
   system_src = GetImage("system", OPTIONS.source_tmp, OPTIONS.source_info_dict)
   system_tgt = GetImage("system", OPTIONS.target_tmp, OPTIONS.target_info_dict)
-  system_diff = BlockDifference("system", system_tgt, system_src)
+  system_diff = common.BlockDifference("system", system_tgt, system_src)
 
   if HasVendorPartition(target_zip):
     if not HasVendorPartition(source_zip):
       raise RuntimeError("can't generate incremental that adds /vendor")
     vendor_src = GetImage("vendor", OPTIONS.source_tmp, OPTIONS.source_info_dict)
     vendor_tgt = GetImage("vendor", OPTIONS.target_tmp, OPTIONS.target_info_dict)
-    vendor_diff = BlockDifference("vendor", vendor_tgt, vendor_src)
+    vendor_diff = common.BlockDifference("vendor", vendor_tgt, vendor_src)
   else:
     vendor_diff = None
 
@@ -867,32 +838,10 @@
 
   device_specific.IncrementalOTA_InstallBegin()
 
-  script.AppendExtra('if range_sha1("%s", "%s") == "%s" then' %
-                     (system_diff.device, system_src.care_map.to_string_raw(),
-                      system_src.TotalSha1()))
-  script.Print("Patching system image...")
-  script.ShowProgress(0.8 if vendor_diff else 0.9, 0)
-  system_diff.WriteScript(script, output_zip)
-  script.AppendExtra(('else\n'
-                      '  (range_sha1("%s", "%s") == "%s") ||\n'
-                      '  abort("system partition has unexpected contents");\n'
-                      'endif;') %
-                     (system_diff.device, system_tgt.care_map.to_string_raw(),
-                      system_tgt.TotalSha1()))
-
+  system_diff.WriteScript(script, output_zip,
+                          progress=0.8 if vendor_diff else 0.9)
   if vendor_diff:
-    script.AppendExtra('if range_sha1("%s", "%s") == "%s" then' %
-                       (vendor_diff.device, vendor_src.care_map.to_string_raw(),
-                        vendor_src.TotalSha1()))
-    script.Print("Patching vendor image...")
-    script.ShowProgress(0.1, 0)
-    vendor_diff.WriteScript(script, output_zip)
-    script.AppendExtra(('else\n'
-                        '  (range_sha1("%s", "%s") == "%s") ||\n'
-                        '  abort("vendor partition has unexpected contents");\n'
-                        'endif;') %
-                       (vendor_diff.device, vendor_tgt.care_map.to_string_raw(),
-                        vendor_tgt.TotalSha1()))
+    vendor_diff.WriteScript(script, output_zip, progress=0.1)
 
   if OPTIONS.two_step:
     common.ZipWriteStr(output_zip, "boot.img", target_boot.data)