from __future__ import print_function
2
3from collections import deque, OrderedDict
4from hashlib import sha1
5import itertools
6import multiprocessing
7import os
8import pprint
9import re
10import subprocess
11import sys
12import threading
13import tempfile
14
15from rangelib import *
16
def compute_patch(src, tgt, imgdiff=False):
  """Compute a binary patch that turns the 'src' data into the 'tgt' data.

  src and tgt are iterables of strings; their elements are concatenated
  and written to temporary files, which are handed to the external
  "imgdiff -z" tool (when imgdiff is True) or "bsdiff" otherwise.

  Returns the patch contents as a string.  Raises ValueError if the
  external diff tool exits with a nonzero status.
  """
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    # The diff tool creates the patch file itself; remove the
    # placeholder that mkstemp created so it can do so.
    try:
      os.unlink(patchfile)
    except OSError:
      pass
    if imgdiff:
      # Silence imgdiff's stdout/stderr.  Open the null device via a
      # context manager so the handle is always closed (the previous
      # code leaked the file object opened for /dev/null).
      with open(os.devnull, "a") as devnull:
        p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                            stdout=devnull,
                            stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    # Best-effort cleanup of the temporary files.
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass
54
class EmptyImage(object):
  """A zero-length image.

  Stands in for the source image when no source is supplied, so the
  resulting transfer list never reads from the original partition.
  """

  # The updater works on 4k blocks (asserted by BlockImageDiff).
  blocksize = 4096
  # An empty partition: no blocks at all, and none of them carry data.
  total_blocks = 0
  care_map = RangeSet()
  file_map = {}

  def ReadRangeSet(self, ranges):
    """Return the data for 'ranges'; always empty for this image."""
    return ()
63
class Transfer(object):
  """One unit of work: write 'tgt_ranges' in the target image, possibly
  reading 'src_ranges' from the source image, using the given style.

  Appending to 'by_id' registers the transfer; its position in that
  list doubles as the transfer's id.
  """

  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    # "Intact" means both range sets report monotonic block order
    # (missing attributes count as not monotonic); ComputePatches uses
    # this to decide whether imgdiff may be applied.
    self.intact = all(getattr(r, "monotonic", False)
                      for r in (tgt_ranges, src_ranges))
    # Ordering-dependency edges, populated later by GenerateDigraph:
    # maps of other Transfer objects to edge weights.
    self.goes_before = {}
    self.goes_after = {}

    self.id = len(by_id)
    by_id.append(self)

  def __str__(self):
    return "%d: <%s %s to %s>" % (
        self.id, self.src_ranges, self.style, self.tgt_ranges)
82
83
84# BlockImageDiff works on two image objects. An image object is
85# anything that provides the following attributes:
86#
87# blocksize: the size in bytes of a block, currently must be 4096.
88#
89# total_blocks: the total size of the partition/image, in blocks.
90#
91# care_map: a RangeSet containing which blocks (in the range [0,
92# total_blocks) we actually care about; i.e. which blocks contain
93# data.
94#
95# file_map: a dict that partitions the blocks contained in care_map
96# into smaller domains that are useful for doing diffs on.
97# (Typically a domain is a file, and the key in file_map is the
98# pathname.)
99#
100# ReadRangeSet(): a function that takes a RangeSet and returns the
101# data contained in the image blocks of that RangeSet. The data
102# is returned as a list or tuple of strings; concatenating the
103# elements together should produce the requested data.
104# Implementations are free to break up the data into list/tuple
105# elements in any way that is convenient.
106#
107# When creating a BlockImageDiff, the src image may be None, in which
108# case the list of transfers produced will never read from the
109# original image.
110
class BlockImageDiff(object):
  """Generates a block-based diff turning a source image into a target.

  See the comment block above for the interface an image object must
  provide.  Compute() drives the whole pipeline and writes its output
  files using a caller-supplied name prefix.
  """

  def __init__(self, tgt, src=None, threads=None):
    """Set up the diff.

    tgt/src are image objects (src defaults to EmptyImage, producing a
    full-image transfer list).  threads defaults to half the CPU count
    (minimum 1) and bounds the patch-computation worker pool.
    """
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    """Run the full diff pipeline, writing "<prefix>.*" output files."""
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    self.RemoveBackwardEdges()
    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    """Write the transfer list to "<prefix>.transfer.list".

    Emits one command line per transfer (new/move/bsdiff/imgdiff/zero),
    preceded by the format version and the total number of target
    blocks written, and followed (or preceded, for a full OTA) by an
    erase command.  Raises ValueError on an unknown transfer style.
    """
    out = []

    out.append("1\n")   # format version number
    total = 0
    performs_read = False

    for xf in self.transfers:

      # zero [rangeset]
      # new [rangeset]
      # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
      # move [src rangeset] [tgt rangeset]
      # erase [rangeset]

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        # A move whose source equals its target is a no-op; emit (and
        # count) nothing in that case.
        if xf.src_ranges != xf.tgt_ranges:
          out.append("%s %s %s\n" % (
              xf.style,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        out.append("%s %d %d %s %s\n" % (
            xf.style, xf.patch_start, xf.patch_len,
            xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        # Blocks already zero in the source needn't be zeroed again.
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        # NOTE: was "raise ValueError, msg" (Python 2-only statement
        # form); the call form behaves identically and is also valid
        # Python 3.
        raise ValueError("unknown transfer style '%s'\n" % (xf.style,))

    out.insert(1, str(total) + "\n")

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

  def ComputePatches(self, prefix):
    """Classify "diff" transfers and compute their patches.

    Each "diff" transfer becomes "move" (source and target data are
    identical), "imgdiff", or "bsdiff"; patch bytes are written to
    "<prefix>.patch.dat" (recording patch_start/patch_len on each
    transfer) and data for "new" transfers to "<prefix>.new.dat".
    Patches are computed by a pool of self.threads worker threads.
    """
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #   ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      # Sorting by target size means workers pop() the largest
      # remaining items first.
      diff_q.sort()

      patches = [None] * patch_num

      lock = threading.Lock()
      def diff_worker():
        # Repeatedly claim an item under the lock, compute its patch
        # outside the lock, then record the result.
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    """Verify the final transfer order is safe and complete."""
    # Simulate the sequences of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.src_ranges)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def RemoveBackwardEdges(self):
    """Trim source blocks so every transfer can run in sequence order.

    For each ordering dependency the chosen sequence violates, the
    offending blocks are dropped from the reader's source set (the data
    will have to come from the patch instead); a "diff" transfer left
    with no source at all becomes "new".
    """
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      io = 0
      ooo = 0
      lost = 0
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          io += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          ooo += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost
      in_order += io
      out_of_order += ooo

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def FindVertexSequence(self):
    """Order the transfers to minimize the weight of violated edges.

    Records the chosen position in each transfer's 'order' field and
    rearranges self.transfers accordingly.
    """
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net difference in source blocks saved we get by
      # pretending it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and by rearranging self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    """Fill in goes_before/goes_after edges between all transfer pairs."""
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    """Build self.transfers: one Transfer per target file-map domain."""
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, None, tgt_ranges, src_ranges, "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      # No usable source; the target data will be sent in full.
      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    """Index source files by basename and by digit-collapsed basename."""
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total