# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Doug Zongkerfc44a512014-08-26 13:10:25 -070015from __future__ import print_function
16
17from collections import deque, OrderedDict
18from hashlib import sha1
19import itertools
20import multiprocessing
21import os
22import pprint
23import re
24import subprocess
25import sys
26import threading
27import tempfile
28
29from rangelib import *
30
Doug Zongkerab7ca1d2014-08-26 10:40:28 -070031__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
32
def compute_patch(src, tgt, imgdiff=False):
  """Run an external diff tool and return the patch it produces.

  The pieces of 'src' and 'tgt' are written to temporary files, the tool
  is run on those files, and the resulting patch file is read back.  All
  temporary files are removed before returning, whether or not the tool
  succeeded.

  Args:
    src: iterable of byte strings; written out in order, they form the
      source file given to the tool.
    tgt: iterable of byte strings forming the target file.
    imgdiff: if true, run "imgdiff -z" with its output discarded;
      otherwise run "bsdiff".

  Returns:
    The contents of the patch file.

  Raises:
    ValueError: if the tool exits with a nonzero status.
  """
  # Every temp file created so far.  Tracking them one at a time means a
  # failure part-way through still cleans up whatever already exists.
  paths = []
  try:
    for prefix, pieces in (("src-", src), ("tgt-", tgt)):
      fd, path = tempfile.mkstemp(prefix=prefix)
      paths.append(path)
      # fdopen takes ownership of fd, so it is closed even if a write fails.
      with os.fdopen(fd, "wb") as f:
        for piece in pieces:
          f.write(piece)
    srcfile, tgtfile = paths

    patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
    os.close(patchfd)
    paths.append(patchfile)

    # mkstemp is only used to pick a unique name; the empty file is removed
    # again before the tool runs (presumably so the tool creates the patch
    # file itself).
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # Discard imgdiff's output, and close the null device afterwards
      # instead of leaking the handle.
      with open(os.devnull, "a") as devnull:
        status = subprocess.call(
            ["imgdiff", "-z", srcfile, tgtfile, patchfile],
            stdout=devnull, stderr=subprocess.STDOUT)
    else:
      status = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if status:
      raise ValueError("diff failed: " + str(status))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Remove each file independently: one failed unlink (e.g. the patch
    # file was never created) must not leave the others behind.
    for path in paths:
      try:
        os.unlink(path)
      except OSError:
        pass

class EmptyImage(object):
  """An image that contains no blocks at all.

  BlockImageDiff substitutes an instance of this class when it is given
  no source image.
  """

  # Class-level attributes of the image interface: nothing is cared
  # about, nothing is mapped, and the image is zero blocks long.
  total_blocks = 0
  blocksize = 4096
  file_map = {}
  care_map = RangeSet()

  def TotalSha1(self):
    """Return the hex SHA-1 digest of no data."""
    hasher = sha1()
    return hasher.hexdigest()

  def ReadRangeSet(self, ranges):
    """Return an empty tuple, whatever 'ranges' is asked for."""
    return ()

class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    """Wrap 'data' and classify each of its blocks as zero or nonzero.

    Args:
      data: the image contents, as one byte string.
      trim: if true, drop a trailing partial block.
      pad: if true, fill out a trailing partial block with zero bytes.
        At most one of 'trim' and 'pad' may be set.

    Raises:
      ValueError: if the length of 'data' is not a multiple of the block
        size and neither 'trim' nor 'pad' was given.
    """
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        # A bytes literal: the same as '\0' on Python 2, and the only form
        # that can be appended to byte data on Python 3.
        self.data += b'\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Floor division keeps the block count an integer even under true
    # division; range() below requires one.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    # Flat lists of (start, end) pairs, one pair per block.
    zero_blocks = []
    nonzero_blocks = []
    reference = b'\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      bucket = zero_blocks if d == reference else nonzero_blocks
      bucket.append(i)
      bucket.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    """Return a list with one string of data per (start, end) block pair
    yielded by iterating over 'ranges'."""
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    """Return the hex SHA-1 of all the data, computed on first use and
    cached on the instance afterwards."""
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1


class Transfer(object):
  """One step of an update: produce the blocks in tgt_ranges, possibly
  using the blocks in src_ranges, by the method named in style.

  Creating a Transfer appends it to the 'by_id' list; its 'id' is the
  index it was given there.
  """

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    # Register first: the id is the position this transfer takes in by_id.
    self.id = len(by_id)
    by_id.append(self)

    self.style = style
    self.tgt_name = tgt_name
    self.tgt_ranges = tgt_ranges
    self.src_name = src_name
    self.src_ranges = src_ranges

    # Ordering edges to other transfers; BlockImageDiff fills these in.
    self.goes_after = {}
    self.goes_before = {}

    # Both range sets must report themselves as monotonic; a range object
    # without that attribute counts as not monotonic.
    tgt_monotonic = getattr(tgt_ranges, "monotonic", False)
    self.intact = tgt_monotonic and getattr(src_ranges, "monotonic", False)

  def __str__(self):
    pieces = (str(self.id), ": <", str(self.src_ranges), " ", self.style,
              " to ", str(self.tgt_ranges), ">")
    return "".join(pieces)


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks)) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (i.e., all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

class BlockImageDiff(object):
  """Works out how to turn a source image into a target image block by
  block, and writes the result to disk.

  Compute(prefix) writes three files: prefix + ".transfer.list" (the
  commands), prefix + ".new.dat" (data for blocks with no usable source)
  and prefix + ".patch.dat" (the concatenated patches).
  """

  def __init__(self, tgt, src=None, threads=None):
    """Set up a diff from 'src' to 'tgt'.

    Args:
      tgt: the target image object.
      src: the source image object; if None, an EmptyImage is used.
      threads: number of worker threads for computing patches; if None,
        half the CPU count is used, but never fewer than one.
    """
    if threads is None:
      threads = max(1, multiprocessing.cpu_count() // 2)
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline and write the output files named by
    'prefix'."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Write the command list to prefix + ".transfer.list".

    Raises:
      ValueError: if a transfer has a style this method does not know.
    """
    out = []

    out.append("1\n")   # format version number
    total = 0
    performs_read = False

    for xf in self.transfers:

      # zero [rangeset]
      # new [rangeset]
      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # move [src rangeset] [tgt rangeset]
      # erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move onto the same blocks needs no command at all.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks that are in the source's zero set need not be zeroed again.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # Call form of raise: valid on both Python 2 and Python 3.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

    # The running total goes on the second line, right after the version.
    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    # The file is opened in binary mode, so hand it bytes.  On Python 2
    # the joined str already is bytes and is written unchanged.
    payload = "".join(out)
    if not isinstance(payload, bytes):
      payload = payload.encode("ascii")
    with open(prefix + ".transfer.list", "wb") as f:
      f.write(payload)

  def ComputePatches(self, prefix):
    """Write prefix + ".new.dat" and prefix + ".patch.dat", and settle the
    final style of every "diff" transfer.

    Each "diff" transfer becomes "move" when source and target data are
    identical, and "imgdiff" or "bsdiff" otherwise.  Patched transfers
    also get 'patch_start' and 'patch_len' set to their position in the
    patch file.

    Raises:
      Whatever compute_patch raised in a worker thread, re-raised here.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      # Always run at least one worker, or the queue would never drain.
      num_workers = max(1, self.threads)
      if num_workers > 1:
        print("Computing patches (using %d threads)..." % (num_workers,))
      else:
        print("Computing patches...")

      # Order by target size only, so workers pop the largest job first.
      # Sorting the bare tuples would, on a size tie, go on to compare the
      # data lists and then the Transfer objects themselves.
      diff_q.sort(key=lambda item: item[0])

      patches = [None] * patch_num
      errors = []

      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            # Stop taking work once any worker has failed.
            if errors or not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          try:
            patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          except Exception as e:
            # Hand the failure to the calling thread; an exception that
            # escapes here would only end this worker.
            with lock:
              errors.append(e)
            return
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size,
                (size * 100.0 / tgt_size) if tgt_size else 0.0,
                xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      workers = [threading.Thread(target=diff_worker)
                 for _ in range(num_workers)]
      for th in workers:
        th.start()
      for th in workers:
        th.join()

      if errors:
        raise errors[0]
    else:
      patches = []

    # Patches are laid out in patch-number order; remember where each one
    # starts so that WriteTransfers can refer to it.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Check by simulation that the chosen transfer order is safe."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """Trim source ranges so that every remaining ordering dependency
    agrees with the order chosen by FindVertexSequence.

    A "diff" transfer left with no source blocks is turned into "new".
    """
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost_source += size - xf.src_ranges.size()

    total_deps = in_order + out_of_order
    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, total_deps,
           (out_of_order * 100.0 / total_deps) if total_deps else 0.0,
           lost_source))

  def FindVertexSequence(self):
    """Reorder self.transfers heuristically and set 'order' on each one."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    def net_gain(v):
      # Weight kept by treating v as a source, minus weight lost by it.
      return sum(v.outgoing.values()) - sum(v.incoming.values())

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.  max() returns the
      # first of several equally good vertices, so ties resolve in G's
      # order.
      u = max(G, key=net_gain)

      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Fill in goes_before/goes_after on every transfer, weighting each
    edge by the number of blocks involved."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create one Transfer per entry in the target's file_map, pairing it
    with the best available source.  Requires AbbreviateSourceNames() to
    have run first."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      if tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No usable source at all: the data has to be shipped in full.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index the source file names by basename and by basename with every
    run of digits replaced by "#".  When several names collide, the one
    seen last wins."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map:
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total