Allow generating OTA package from non-sparse images.

Test: build OTA package in cuttlefish

Bug: 120041578
Change-Id: I246c38e08376c837b7f126aa19cb8c1d73ed1e26
diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index b7c33f5..2e5f804 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -187,6 +187,78 @@
       fd.write(data)
 
 
+class FileImage(Image):
+  """An image wrapped around a raw image file."""
+
+  def __init__(self, path, hashtree_info_generator=None):
+    self.path = path
+    self.blocksize = 4096
+    self._file_size = os.path.getsize(self.path)
+    self._file = open(self.path, 'r')
+
+    if self._file_size % self.blocksize != 0:
+      raise ValueError("Size of file %s must be multiple of %d bytes, but is %d"
+                       % self.path, self.blocksize, self._file_size)
+
+    self.total_blocks = self._file_size / self.blocksize
+    self.care_map = RangeSet(data=(0, self.total_blocks))
+    self.clobbered_blocks = RangeSet()
+    self.extended = RangeSet()
+
+    self.hashtree_info = None
+    if hashtree_info_generator:
+      self.hashtree_info = hashtree_info_generator.Generate(self)
+
+    zero_blocks = []
+    nonzero_blocks = []
+    reference = '\0' * self.blocksize
+
+    for i in range(self.total_blocks):
+      d = self._file.read(self.blocksize)
+      if d == reference:
+        zero_blocks.append(i)
+        zero_blocks.append(i+1)
+      else:
+        nonzero_blocks.append(i)
+        nonzero_blocks.append(i+1)
+
+    assert zero_blocks or nonzero_blocks
+
+    self.file_map = {}
+    if zero_blocks:
+      self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
+    if nonzero_blocks:
+      self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
+    if self.hashtree_info:
+      self.file_map["__HASHTREE"] = self.hashtree_info.hashtree_range
+
+  def __del__(self):
+    self._file.close()
+
+  def _GetRangeData(self, ranges):
+    for s, e in ranges:
+      self._file.seek(s * self.blocksize)
+      for _ in range(s, e):
+        yield self._file.read(self.blocksize)
+
+  def RangeSha1(self, ranges):
+    h = sha1()
+    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
+      h.update(data)
+    return h.hexdigest()
+
+  def ReadRangeSet(self, ranges):
+    return list(self._GetRangeData(ranges))
+
+  def TotalSha1(self, include_clobbered_blocks=False):
+    assert not self.clobbered_blocks
+    return self.RangeSha1(self.care_map)
+
+  def WriteRangeDataToFd(self, ranges, fd):
+    for data in self._GetRangeData(ranges): # pylint: disable=not-an-iterable
+      fd.write(data)
+
+
 class Transfer(object):
   def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
                src_sha1, style, by_id):
diff --git a/tools/releasetools/common.py b/tools/releasetools/common.py
index 171c794..3e2a113 100644
--- a/tools/releasetools/common.py
+++ b/tools/releasetools/common.py
@@ -824,6 +824,77 @@
   return tmp
 
 
+def GetUserImage(which, tmpdir, input_zip,
+                 info_dict=None,
+                 allow_shared_blocks=None,
+                 hashtree_info_generator=None,
+                 reset_file_map=False):
+  """Returns an Image object suitable for passing to BlockImageDiff.
+
+  This function loads the specified image from the given path. If the specified
+  image is sparse, it also performs additional processing for OTA purpose. For
+  example, it always adds block 0 to clobbered blocks list. It also detects
+  files that cannot be reconstructed from the block list, for whom we should
+  avoid applying imgdiff.
+
+  Args:
+    which: The partition name.
+    tmpdir: The directory that contains the prebuilt image and block map file.
+    input_zip: The target-files ZIP archive.
+    info_dict: The dict to be looked up for relevant info.
+    allow_shared_blocks: If image is sparse, whether having shared blocks is
+        allowed. If none, it is looked up from info_dict.
+    hashtree_info_generator: If present and image is sparse, generates the
+        hashtree_info for this sparse image.
+    reset_file_map: If true and image is sparse, reset file map before returning
+        the image.
+  Returns:
+    A Image object. If it is a sparse image and reset_file_map is False, the
+    image will have file_map info loaded.
+  """
+  if info_dict == None:
+    info_dict = LoadInfoDict(input_zip)
+
+  is_sparse = info_dict.get("extfs_sparse_flag")
+
+  # When target uses 'BOARD_EXT4_SHARE_DUP_BLOCKS := true', images may contain
+  # shared blocks (i.e. some blocks will show up in multiple files' block
+  # list). We can only allocate such shared blocks to the first "owner", and
+  # disable imgdiff for all later occurrences.
+  if allow_shared_blocks is None:
+    allow_shared_blocks = info_dict.get("ext4_share_dup_blocks") == "true"
+
+  if is_sparse:
+    img = GetSparseImage(which, tmpdir, input_zip, allow_shared_blocks,
+                         hashtree_info_generator)
+    if reset_file_map:
+      img.ResetFileMap()
+    return img
+  else:
+    return GetNonSparseImage(which, tmpdir, hashtree_info_generator)
+
+
+def GetNonSparseImage(which, tmpdir, hashtree_info_generator=None):
+  """Returns a Image object suitable for passing to BlockImageDiff.
+
+  This function loads the specified non-sparse image from the given path.
+
+  Args:
+    which: The partition name.
+    tmpdir: The directory that contains the prebuilt image and block map file.
+  Returns:
+    A Image object.
+  """
+  path = os.path.join(tmpdir, "IMAGES", which + ".img")
+  mappath = os.path.join(tmpdir, "IMAGES", which + ".map")
+
+  # The image and map files must have been created prior to calling
+  # ota_from_target_files.py (since LMP).
+  assert os.path.exists(path) and os.path.exists(mappath)
+
+  return blockimgdiff.FileImage(path, hashtree_info_generator=
+                                hashtree_info_generator)
+
 def GetSparseImage(which, tmpdir, input_zip, allow_shared_blocks,
                    hashtree_info_generator=None):
   """Returns a SparseImage object suitable for passing to BlockImageDiff.
@@ -2068,7 +2139,7 @@
 
 
 DataImage = blockimgdiff.DataImage
-
+EmptyImage = blockimgdiff.EmptyImage
 
 # map recovery.fstab's fs_types to mount/format "partition types"
 PARTITION_TYPES = {
diff --git a/tools/releasetools/ota_from_target_files.py b/tools/releasetools/ota_from_target_files.py
index ad40bd4..dd3e190 100755
--- a/tools/releasetools/ota_from_target_files.py
+++ b/tools/releasetools/ota_from_target_files.py
@@ -917,17 +917,14 @@
 
   script.ShowProgress(system_progress, 0)
 
-  # See the notes in WriteBlockIncrementalOTAPackage().
-  allow_shared_blocks = target_info.get('ext4_share_dup_blocks') == "true"
-
   def GetBlockDifference(partition):
     # Full OTA is done as an "incremental" against an empty source image. This
     # has the effect of writing new data from the package to the entire
     # partition, but lets us reuse the updater code that writes incrementals to
     # do it.
-    tgt = common.GetSparseImage(partition, OPTIONS.input_tmp, input_zip,
-                                allow_shared_blocks)
-    tgt.ResetFileMap()
+    tgt = common.GetUserImage(partition, OPTIONS.input_tmp, input_zip,
+                              info_dict=target_info,
+                              reset_file_map=True)
     diff = common.BlockDifference(partition, tgt, src=None)
     return diff
 
@@ -1512,8 +1509,10 @@
   device_specific = common.DeviceSpecificParams(
       source_zip=source_zip,
       source_version=source_api_version,
+      source_tmp=OPTIONS.source_tmp,
       target_zip=target_zip,
       target_version=target_api_version,
+      target_tmp=OPTIONS.target_tmp,
       output_zip=output_zip,
       script=script,
       metadata=metadata,
@@ -1529,20 +1528,20 @@
   target_recovery = common.GetBootableImage(
       "/tmp/recovery.img", "recovery.img", OPTIONS.target_tmp, "RECOVERY")
 
-  # When target uses 'BOARD_EXT4_SHARE_DUP_BLOCKS := true', images may contain
-  # shared blocks (i.e. some blocks will show up in multiple files' block
-  # list). We can only allocate such shared blocks to the first "owner", and
-  # disable imgdiff for all later occurrences.
+  # See notes in common.GetUserImage()
   allow_shared_blocks = (source_info.get('ext4_share_dup_blocks') == "true" or
                          target_info.get('ext4_share_dup_blocks') == "true")
-  system_src = common.GetSparseImage("system", OPTIONS.source_tmp, source_zip,
-                                     allow_shared_blocks)
+  system_src = common.GetUserImage("system", OPTIONS.source_tmp, source_zip,
+                                   info_dict=source_info,
+                                   allow_shared_blocks=allow_shared_blocks)
 
   hashtree_info_generator = verity_utils.CreateHashtreeInfoGenerator(
       "system", 4096, target_info)
-  system_tgt = common.GetSparseImage("system", OPTIONS.target_tmp, target_zip,
-                                     allow_shared_blocks,
-                                     hashtree_info_generator)
+  system_tgt = common.GetUserImage("system", OPTIONS.target_tmp, target_zip,
+                                   info_dict=target_info,
+                                   allow_shared_blocks=allow_shared_blocks,
+                                   hashtree_info_generator=
+                                   hashtree_info_generator)
 
   blockimgdiff_version = max(
       int(i) for i in target_info.get("blockimgdiff_versions", "1").split(","))
@@ -1567,13 +1566,16 @@
   if HasVendorPartition(target_zip):
     if not HasVendorPartition(source_zip):
       raise RuntimeError("can't generate incremental that adds /vendor")
-    vendor_src = common.GetSparseImage("vendor", OPTIONS.source_tmp, source_zip,
-                                       allow_shared_blocks)
+    vendor_src = common.GetUserImage("vendor", OPTIONS.source_tmp, source_zip,
+                                     info_dict=source_info,
+                                     allow_shared_blocks=allow_shared_blocks)
     hashtree_info_generator = verity_utils.CreateHashtreeInfoGenerator(
         "vendor", 4096, target_info)
-    vendor_tgt = common.GetSparseImage(
-        "vendor", OPTIONS.target_tmp, target_zip, allow_shared_blocks,
-        hashtree_info_generator)
+    vendor_tgt = common.GetUserImage(
+        "vendor", OPTIONS.target_tmp, target_zip,
+        info_dict=target_info,
+        allow_shared_blocks=allow_shared_blocks,
+        hashtree_info_generator=hashtree_info_generator)
 
     # Check first block of vendor partition for remount R/W only if
     # disk type is ext4
diff --git a/tools/releasetools/test_blockimgdiff.py b/tools/releasetools/test_blockimgdiff.py
index 1aabaa2..b6d47d4 100644
--- a/tools/releasetools/test_blockimgdiff.py
+++ b/tools/releasetools/test_blockimgdiff.py
@@ -14,9 +14,13 @@
 # limitations under the License.
 #
 
+import os
+from hashlib import sha1
+
 import common
 from blockimgdiff import (
-    BlockImageDiff, DataImage, EmptyImage, HeapItem, ImgdiffStats, Transfer)
+    BlockImageDiff, DataImage, EmptyImage, FileImage, HeapItem, ImgdiffStats,
+    Transfer)
 from rangelib import RangeSet
 from test_utils import ReleaseToolsTestCase
 
@@ -264,7 +268,42 @@
 
 
 class DataImageTest(ReleaseToolsTestCase):
-    def test_read_range_set(self):
-        data = "file" + ('\0' * 4092)
-        image = DataImage(data)
-        self.assertEqual(data, "".join(image.ReadRangeSet(image.care_map)))
+  def test_read_range_set(self):
+    data = "file" + ('\0' * 4092)
+    image = DataImage(data)
+    self.assertEqual(data, "".join(image.ReadRangeSet(image.care_map)))
+
+
+class FileImageTest(ReleaseToolsTestCase):
+  def setUp(self):
+    self.file_path = common.MakeTempFile()
+    self.data = os.urandom(4096 * 4)
+    with open(self.file_path, 'w') as f:
+      f.write(self.data)
+    self.file = FileImage(self.file_path)
+
+  def test_totalsha1(self):
+    self.assertEqual(sha1(self.data).hexdigest(), self.file.TotalSha1())
+
+  def test_ranges(self):
+    blocksize = self.file.blocksize
+    for s in range(4):
+      for e in range(s, 4):
+        expected_data = self.data[s * blocksize : e * blocksize]
+
+        rs = RangeSet([s, e])
+        data = "".join(self.file.ReadRangeSet(rs))
+        self.assertEqual(expected_data, data)
+
+        sha1sum = self.file.RangeSha1(rs)
+        self.assertEqual(sha1(expected_data).hexdigest(), sha1sum)
+
+        tmpfile = common.MakeTempFile()
+        with open(tmpfile, 'w') as f:
+          self.file.WriteRangeDataToFd(rs, f)
+        with open(tmpfile, 'r') as f:
+          self.assertEqual(expected_data, f.read())
+
+  def test_read_all(self):
+    data = "".join(self.file.ReadRangeSet(self.file.care_map))
+    self.assertEqual(self.data, data)