1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6try:
7 import queue
8except ImportError:
9 import Queue as queue
12import logging
13import multiprocessing
14import traceback
15import errno
# Module-level logger, registered under the name 'pyrocko.parimap'.
logger = logging.getLogger('pyrocko.parimap')
def worker(
        q_in, q_out, function, eprintignore, pshared,
        startup, startup_args, cleanup):
    '''
    Process jobs from an input queue until a stop marker is received.

    Each job is a pair ``(index, args)`` taken from ``q_in``. For every job a
    triple ``(index, result, exception)`` is put on ``q_out``; ``result`` is
    ``None`` when the call raised, ``exception`` is ``None`` when it
    succeeded. A job whose index is ``None`` ends the loop.

    :param q_in: queue-like object providing ``get()``
    :param q_out: queue-like object providing ``put()``
    :param function: callable invoked as ``function(*args)``, with an extra
        ``pshared`` keyword argument if ``pshared`` is not ``None``
    :param eprintignore: ``None`` to print a traceback for every exception
        caught, otherwise an exception class (or tuple of classes) for which
        no traceback is printed
    :param pshared: object handed to every call as keyword ``pshared``,
        or ``None`` to pass no such keyword
    :param startup: optional callable run once before the first job
    :param startup_args: positional arguments for ``startup``
    :param cleanup: optional callable run once after the stop marker
    '''
    extra = {} if pshared is None else {'pshared': pshared}

    if startup is not None:
        startup(*startup_args)

    while True:
        index, call_args = q_in.get()
        if index is None:
            # Stop marker: leave the loop, cleanup happens below.
            break

        try:
            outcome = (index, function(*call_args, **extra), None)
        except Exception as err:
            silenced = (
                eprintignore is not None and isinstance(err, eprintignore))
            if not silenced:
                traceback.print_exc()

            # The exception travels back to the consumer instead of a result.
            outcome = (index, None, err)

        q_out.put(outcome)

    if cleanup is not None:
        cleanup()
def parimap(function, *iterables, **kwargs):
    '''
    Lazily map ``function`` over ``iterables``, optionally in parallel.

    This is a generator. ``function`` is called with one item from each
    iterable per call, and iteration ends as soon as any iterable is
    exhausted. Results are yielded in input order, regardless of the order in
    which the worker processes finish.

    :param function: callable applied to the items
    :param iterables: one iterable per positional argument of ``function``
    :param nprocs: number of worker processes. ``None`` (default) uses
        ``multiprocessing.cpu_count()``. With ``1`` everything runs in the
        calling process and no subprocess or queue is created.
    :param eprintignore: exception class (or tuple of classes) for which the
        workers should not print a traceback. The default ``'all'`` is handed
        to the workers as ``None``, in which case they print a traceback for
        every exception. Not used when ``nprocs == 1``.
    :param pshared: if not ``None``, passed to every call of ``function`` as
        keyword argument ``pshared``
    :param startup: optional callable run once per worker process (or once in
        the calling process when ``nprocs == 1``) before any item is handled
    :param startup_args: positional arguments for ``startup``
    :param cleanup: optional callable run once per worker when it is told to
        stop (or when the generator finishes, for ``nprocs == 1``)

    :raises ValueError: if ``nprocs`` is less than 1
    :raises Exception: an exception raised by ``function`` is re-raised here
        once all results preceding it in input order have been yielded
    '''
    assert all(
        k in (
            'nprocs', 'eprintignore', 'pshared', 'startup', 'startup_args',
            'cleanup')
        for k in kwargs.keys())

    nprocs = kwargs.get('nprocs', None)
    eprintignore = kwargs.get('eprintignore', 'all')
    pshared = kwargs.get('pshared', None)
    startup = kwargs.get('startup', None)
    startup_args = kwargs.get('startup_args', ())
    cleanup = kwargs.get('cleanup', None)

    if eprintignore == 'all':
        eprintignore = None

    if nprocs == 1:
        # Serial fallback: same semantics, but without any multiprocessing.
        iterables = list(map(iter, iterables))
        kwargs = {}
        if pshared is not None:
            kwargs['pshared'] = pshared

        if startup is not None:
            startup(*startup_args)

        try:
            while True:
                args = []
                for it in iterables:
                    try:
                        args.append(next(it))
                    except StopIteration:
                        return

                yield function(*args, **kwargs)

        finally:
            if cleanup is not None:
                cleanup()

        return

    if nprocs is None:
        nprocs = multiprocessing.cpu_count()

    if nprocs < 1:
        # Without this check the scheduling loop below would never enqueue
        # anything and never terminate.
        raise ValueError('nprocs must be at least 1, got %r' % (nprocs,))

    # The input queue holds at most one pending job, so that arguments are
    # pulled from the iterables only shortly before they are needed.
    q_in = multiprocessing.Queue(1)
    q_out = multiprocessing.Queue()

    procs = []

    results = []            # finished (index, result, exception) triples
    nrun = 0                # jobs handed out and not yet collected
    nwritten = 0            # index given to the next job
    iout = 0                # index of the next result to be yielded
    all_written = False     # True once the stop markers have been sent
    error_ahead = False     # True once any collected result is an exception
    iterables = list(map(iter, iterables))
    while True:
        if nrun < nprocs and not all_written and not error_ahead:
            args = []
            for it in iterables:
                try:
                    args.append(next(it))
                except StopIteration:
                    pass

            if len(args) == len(iterables):
                # Start workers on demand, never more than jobs in flight.
                if len(procs) < nrun + 1:
                    p = multiprocessing.Process(
                        target=worker,
                        args=(q_in, q_out, function, eprintignore, pshared,
                              startup, startup_args, cleanup))
                    p.daemon = True
                    p.start()
                    procs.append(p)

                q_in.put((nwritten, args))
                nwritten += 1
                nrun += 1
            else:
                # At least one iterable is exhausted: tell every worker to
                # stop once the jobs queued before the marker are done.
                all_written = True
                for _ in procs:
                    q_in.put((None, None))

                q_in.close()

        try:
            while nrun > 0:
                if nrun < nprocs and not all_written and not error_ahead:
                    # There is room for more jobs: only pick up what is
                    # already available, then go back to enqueuing.
                    results.append(q_out.get_nowait())
                else:
                    # Nothing more can be enqueued right now: wait for a
                    # result, retrying if the wait is interrupted by a signal.
                    while True:
                        try:
                            results.append(q_out.get())
                            break
                        except IOError as e:
                            if e.errno != errno.EINTR:
                                raise

                nrun -= 1

        except queue.Empty:
            pass

        if results:
            results.sort()
            # check for error ahead to prevent further enqueuing
            if any(exc is not None for (_, _, exc) in results):
                error_ahead = True

            # Hand out every result that is next in input order.
            while results:
                (i, r, exc) = results[0]
                if i == iout:
                    results.pop(0)
                    if exc is not None:
                        if not all_written:
                            for _ in procs:
                                q_in.put((None, None))

                            q_in.close()

                        raise exc
                    else:
                        yield r

                    iout += 1
                else:
                    break

        if all_written and nrun == 0:
            break

    for p in procs:
        p.join()

    return