blob: 6f94edab58249902268588ad71afea2ec5bf0ebb [file] [log] [blame]
Doug Zongkerfc44a512014-08-26 13:10:25 -07001from __future__ import print_function
2
3from collections import deque, OrderedDict
4from hashlib import sha1
5import itertools
6import multiprocessing
7import os
8import pprint
9import re
10import subprocess
11import sys
12import threading
13import tempfile
14
15from rangelib import *
16
Doug Zongkerab7ca1d2014-08-26 10:40:28 -070017__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
18
def compute_patch(src, tgt, imgdiff=False):
  """Compute and return the binary patch that transforms src into tgt.

  Args:
    src: iterable of strings; concatenated, the source data.
    tgt: iterable of strings; concatenated, the target data.
    imgdiff: if True, run "imgdiff -z" (better for zip-format files);
        otherwise run "bsdiff".

  Returns:
    The raw patch contents as a string.

  Raises:
    ValueError: if the external diff tool exits with a nonzero status.
  """
  # Stage the data in temp files, since the diff tools work on files.
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # The diff tool writes the patch file itself; remove the placeholder
    # mkstemp created so the tool starts from a clean slate.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      # imgdiff is chatty; discard its output.  Use a context manager so
      # the devnull handle is closed instead of leaked (the original left
      # the open("/dev/null") handle dangling).
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Best-effort cleanup of all three temp files.
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass
56
class EmptyImage(object):
  """A zero-length image: no blocks, no files, nothing to read."""

  # Fixed block size used throughout the updater.
  blocksize = 4096
  # An empty image has no blocks at all ...
  total_blocks = 0
  # ... so nothing is in the care map, and no files partition it.
  care_map = RangeSet()
  file_map = {}

  def ReadRangeSet(self, ranges):
    # There is no data, so every read yields nothing.
    return ()

  def TotalSha1(self):
    # The SHA-1 of zero bytes of input.
    return sha1().hexdigest()
67
68
class DataImage(object):
  """An image wrapped around a single string of data.

  Args:
    data: the raw image contents as one string.
    trim: if True, drop any trailing partial block.
    pad: if True, zero-fill any trailing partial block to a full block.

  At most one of trim/pad may be set.  If the data is not a whole number
  of blocks and neither is set, ValueError is raised.
  """

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    # Use floor division: under Python 3 (or "from __future__ import
    # division") "/" would make total_blocks a float; "//" gives the same
    # int result on Python 2.
    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    # Classify each block as all-zero or not, accumulating the flat
    # start/end pair lists handed to RangeSet below.
    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(zero_blocks),
                     "__NONZERO": RangeSet(nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    # One element per (start, end) block range.
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    # Computed lazily and cached; the data never changes after __init__.
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1
117
Doug Zongkerfc44a512014-08-26 13:10:25 -0700118
class Transfer(object):
  """One unit of work: produce tgt_ranges in the target image, optionally
  reading src_ranges from the source, using the given transfer style."""

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style

    # "Intact" only when both range sets report themselves monotonic;
    # objects without the attribute (or with it falsy) are not intact.
    tgt_mono = getattr(tgt_ranges, "monotonic", False)
    src_mono = getattr(src_ranges, "monotonic", False)
    self.intact = tgt_mono and src_mono

    # Ordering-dependency edges, populated later (maps Transfer -> weight).
    self.goes_before = {}
    self.goes_after = {}

    # Register with the shared list; our position in it doubles as our id.
    self.id = len(by_id)
    by_id.append(self)

  def __str__(self):
    return "%s: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
137
138
139# BlockImageDiff works on two image objects. An image object is
140# anything that provides the following attributes:
141#
142# blocksize: the size in bytes of a block, currently must be 4096.
143#
144# total_blocks: the total size of the partition/image, in blocks.
145#
#    care_map: a RangeSet containing which blocks (in the range
#       [0, total_blocks)) we actually care about; i.e. which blocks
#       contain data.
149#
150# file_map: a dict that partitions the blocks contained in care_map
151# into smaller domains that are useful for doing diffs on.
152# (Typically a domain is a file, and the key in file_map is the
153# pathname.)
154#
155# ReadRangeSet(): a function that takes a RangeSet and returns the
156# data contained in the image blocks of that RangeSet. The data
157# is returned as a list or tuple of strings; concatenating the
158# elements together should produce the requested data.
159# Implementations are free to break up the data into list/tuple
160# elements in any way that is convenient.
161#
Doug Zongkerab7ca1d2014-08-26 10:40:28 -0700162# TotalSha1(): a function that returns (as a hex string) the SHA-1
163# hash of all the data in the image (ie, all the blocks in the
164# care_map)
165#
Doug Zongkerfc44a512014-08-26 13:10:25 -0700166# When creating a BlockImageDiff, the src image may be None, in which
167# case the list of transfers produced will never read from the
168# original image.
169
class BlockImageDiff(object):
  """Computes a block-level diff between two images.

  Given a target image and an optional source image (see the interface
  description in the comment above), Compute(prefix) writes three files:
  <prefix>.transfer.list, <prefix>.new.dat, and <prefix>.patch.dat,
  which together describe how to turn the source partition into the
  target partition.
  """

  def __init__(self, tgt, src=None, threads=None):
    if threads is None:
      # Default to half the CPUs, but always at least one thread.
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the whole pipeline, writing outputs named <prefix>.*."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Write the ordered command list to <prefix>.transfer.list."""
    out = []

    out.append("1\n")   # format version number
    total = 0           # total blocks written by all commands
    performs_read = False

    for xf in self.transfers:

      # Command formats emitted below:
      #
      #   zero [rangeset]
      #   new [rangeset]
      #   bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      #   imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      #   move [src rangeset] [tgt rangeset]
      #   erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move from a range onto itself writes nothing; skip it.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Only zero blocks that aren't already zero in the source.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # Exception-call syntax; the old "raise ValueError, msg" form is
        # a syntax error under Python 3.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

    # Line 2 of the output is the total number of blocks written.
    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    # NOTE(review): opened in binary mode but writes str; fine on
    # Python 2, where this code runs.
    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

  def ComputePatches(self, prefix):
    """Decide the final style of each transfer and compute its patch.

    "diff" transfers become "move" (identical data), or "bsdiff"/
    "imgdiff" with an actual computed patch.  Writes <prefix>.new.dat
    (all "new" data) and <prefix>.patch.dat (all patches, concatenated).
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sorted by tgt_size so workers pop the largest jobs first.
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        # Pull jobs off diff_q until it's empty; the lock guards both
        # the queue and the shared patches list.
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    # Concatenate all patches into one file, recording each transfer's
    # offset and length within it.
    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Verify the transfer order never reads a block after writing it."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """Trim source blocks so the chosen order violates no dependency.

    For every dependency the vertex sequence failed to satisfy, drop the
    conflicting blocks from the reader's source set (falling back to
    "new" when a diff has no source left), and report statistics.
    """
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print((" %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def FindVertexSequence(self):
    """Order the transfers to (heuristically) minimize violated edges."""
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Build goes_before/goes_after edges between all transfer pairs."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Create one Transfer per target file, matching sources per Compute."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No plausible source; the target data is sent in full.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and by digit-collapsed basename."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      # Replace every run of digits with "#" for fuzzy matching of
      # version-numbered filenames.
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total