Brian Gesiak | fb66991 | 2017-06-29 18:56:25 +0000 | [diff] [blame] | 1 | import sys |
| 2 | import multiprocessing |
| 3 | |
| 4 | |
# Shared progress counters read by _wrapped_func. pmap() rebinds both to
# fresh multiprocessing.Value('i', ...) objects on each call, and _init()
# installs them inside every pool worker. Both start out as None.
_current = _total = None
| 7 | |
| 8 | |
def _init(current, total):
    """Install the shared progress counters in the current process.

    Used as the ``initializer`` of the worker pool created by ``pmap``: it
    binds the module-level ``_current`` and ``_total`` names to the objects
    passed in, so that ``_wrapped_func`` can read and update them.
    """
    global _current, _total
    _current, _total = current, total
| 14 | |
| 15 | |
def _wrapped_func(func_and_args):
    """Run one ``pmap`` work item and optionally report progress.

    *func_and_args* is a ``(func, argument, should_print_progress)`` tuple.
    ``func(argument)`` is evaluated first; only after it returns is the shared
    ``_current`` counter incremented and an 'N of M' progress line written.
    The count therefore reflects completed items rather than started ones.

    Returns the value of ``func(argument)``. If ``func`` raises, the exception
    propagates and the progress counter is not advanced.
    """
    func, argument, should_print_progress = func_and_args

    # Do the work before reporting, so 'N of M' counts finished items.
    result = func(argument)

    if should_print_progress:
        # Holding the counter's lock makes the read-modify-write of .value
        # atomic across workers and keeps their progress writes from
        # interleaving on stdout.
        with _current.get_lock():
            _current.value += 1
            sys.stdout.write('\r\t{} of {}'.format(_current.value, _total.value))
            sys.stdout.flush()

    return result
| 26 | |
| 27 | |
def pmap(func, iterable, processes, should_print_progress, *args, **kwargs):
    """
    A parallel map function that reports on its progress.

    Applies `func` to every item of `iterable` and returns a list of the
    results, in input order. If `processes` is 1, the calls run serially in
    the calling process. Otherwise a `multiprocessing.Pool` is created with
    `processes=processes` and the calls run in its worker processes.
    `should_print_progress` is a boolean value that indicates whether a
    string 'N of M' should be printed to indicate how many of the functions
    have finished being run.

    `iterable` may be any finite iterable; it is materialised into a list up
    front so that its length is known. Extra `*args`/`**kwargs` are forwarded
    to `Pool.map` (e.g. `chunksize`). They are ignored in the serial case,
    where they have no meaning.

    If any call raises, the exception propagates; in the parallel case the
    pool is terminated and its workers are joined first.
    """
    global _current
    global _total

    # Build the work list before taking its length so that generators and
    # other unsized iterables are accepted.
    func_and_args = [(func, arg, should_print_progress) for arg in iterable]

    _current = multiprocessing.Value('i', 0)
    _total = multiprocessing.Value('i', len(func_and_args))

    if processes == 1:
        # Builtin map() would treat extra positional args as further
        # iterables and rejects keyword args, so the Pool.map options in
        # *args/**kwargs are deliberately not forwarded here.
        result = list(map(_wrapped_func, func_and_args))
    else:
        pool = multiprocessing.Pool(initializer=_init,
                                    initargs=(_current, _total,),
                                    processes=processes)
        try:
            result = pool.map(_wrapped_func, func_and_args, *args, **kwargs)
        except BaseException:
            # After a failure or Ctrl-C, stop the workers instead of waiting
            # for outstanding work.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            # Reap the worker processes on every path so none are leaked.
            pool.join()

    if should_print_progress:
        sys.stdout.write('\r')
        sys.stdout.flush()
    return result