1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6try:
7 import queue
8except ImportError:
9 import Queue as queue
12import logging
13import platform
14import multiprocessing
15import traceback
16import errno
# Module-level logger; parimap() uses it to warn when it falls back to serial
# execution on Windows.
logger = logging.getLogger('pyrocko.parimap')
def worker(
        q_in, q_out, function, eprintignore, pshared,
        startup, startup_args, cleanup):
    '''
    Task loop executed by each child process that :py:func:`parimap` starts.

    Tasks are ``(index, args)`` pairs read from *q_in*. For each one,
    ``function(*args)`` is evaluated (with the extra keyword argument
    ``pshared=pshared`` when *pshared* is not ``None``) and
    ``(index, result, exception)`` is put on *q_out*. On success
    ``exception`` is ``None``; on failure ``result`` is ``None`` and the
    caught exception is sent back instead. A task whose index is ``None``
    is the shutdown signal: the loop ends and *cleanup* (if given) is run.

    :param q_in: queue supplying ``(index, args)`` tasks.
    :param q_out: queue receiving ``(index, result, exception)`` tuples.
    :param function: callable to evaluate for each task.
    :param eprintignore: exception class (or tuple of classes) whose
        tracebacks are not printed; ``None`` prints every traceback.
    :param pshared: optional object passed to every call as ``pshared=``.
    :param startup: optional callable run once, as
        ``startup(*startup_args)``, before the first task is read.
    :param startup_args: positional arguments for *startup*.
    :param cleanup: optional callable run once after the shutdown signal.
    '''

    call_kwargs = {} if pshared is None else {'pshared': pshared}

    if startup is not None:
        startup(*startup_args)

    while True:
        index, args = q_in.get()
        if index is None:
            # Shutdown sentinel.
            break

        result = None
        error = None
        try:
            result = function(*args, **call_kwargs)
        except Exception as exc:
            # The traceback is printed here, in the child, because only the
            # exception object (not its traceback) is sent back on q_out.
            quiet = eprintignore is not None and isinstance(exc, eprintignore)
            if not quiet:
                traceback.print_exc()
            error = exc

        q_out.put((index, result, error))

    if cleanup is not None:
        cleanup()
def parimap(function, *iterables, **kwargs):
    '''
    Order-preserving, optionally parallel version of :py:func:`map`.

    Arguments are drawn in lock-step from *iterables*; iteration stops as
    soon as the first of them is exhausted (like ``zip``). ``function(*args)``
    is evaluated for every argument tuple and the results are yielded in
    input order. With ``nprocs=1`` (always forced on Windows) everything runs
    in the calling process; otherwise the work is distributed over up to
    *nprocs* worker processes created with :py:class:`multiprocessing.Process`.

    :param function: callable to apply. If *pshared* is given, it is passed
        to every call as the keyword argument ``pshared``.
    :param iterables: iterables supplying the positional arguments.
    :param nprocs: number of worker processes. ``None`` (default) uses
        :py:func:`multiprocessing.cpu_count`; ``1`` runs serially.
    :param eprintignore: exception class (or tuple of classes) whose
        tracebacks the workers should not print. The default, ``'all'``,
        means tracebacks of all exceptions are printed. Unused in serial
        mode, where exceptions simply propagate.
    :param pshared: optional object passed to every call as ``pshared=``.
    :param startup: optional callable run as ``startup(*startup_args)`` once
        per worker process (once in serial mode) before any work is done.
    :param startup_args: positional arguments for *startup* (default ``()``).
    :param cleanup: optional callable run once per worker process when it
        receives its shutdown signal; in serial mode it runs when the
        generator finishes, fails or is closed.
    :raises TypeError: if an unsupported keyword argument is given.
    :raises ValueError: if *nprocs* is smaller than 1.
    :raises Exception: the first exception (in input order) raised by
        *function*; all results preceding it are yielded first.
    '''

    allowed = (
        'nprocs', 'eprintignore', 'pshared', 'startup', 'startup_args',
        'cleanup')

    # Explicit check rather than ``assert`` so that misspelled options are
    # still rejected when Python runs with -O.
    unknown = [k for k in kwargs if k not in allowed]
    if unknown:
        raise TypeError(
            'parimap() got unexpected keyword argument(s): %s'
            % ', '.join(sorted(unknown)))

    nprocs = kwargs.get('nprocs', None)
    eprintignore = kwargs.get('eprintignore', 'all')
    pshared = kwargs.get('pshared', None)
    startup = kwargs.get('startup', None)
    startup_args = kwargs.get('startup_args', ())
    cleanup = kwargs.get('cleanup', None)

    if eprintignore == 'all':
        # None tells the workers to print the traceback of every exception.
        eprintignore = None

    if platform.system() == 'Windows' and nprocs != 1:
        logger.warning(
            'The parimap module relies on fork() for parallelism. This does '
            'not work on Windows. Using serial code.')
        nprocs = 1

    iterators = [iter(it) for it in iterables]

    if nprocs == 1:
        call_kwargs = {}
        if pshared is not None:
            call_kwargs['pshared'] = pshared

        if startup is not None:
            startup(*startup_args)

        try:
            # This loop is only left via ``return`` or an exception, so the
            # parallel code below is never reached in serial mode.
            while True:
                args = []
                for it in iterators:
                    try:
                        args.append(next(it))
                    except StopIteration:
                        return

                yield function(*args, **call_kwargs)

        finally:
            if cleanup is not None:
                cleanup()

    if nprocs is None:
        nprocs = multiprocessing.cpu_count()

    if nprocs < 1:
        # The scheduling loop below would otherwise spin forever.
        raise ValueError('parimap: nprocs must be >= 1, got %r' % (nprocs,))

    # Task queue; holds at most one task that no worker has picked up yet.
    q_in = multiprocessing.Queue(1)
    # Result queue; receives (index, result, exception) from the workers.
    q_out = multiprocessing.Queue()

    procs = []
    pending = {}  # index -> (result, exception): received, not yet yielded
    nrun = 0  # submitted tasks whose result has not been received yet
    nwritten = 0  # index of the next task to submit
    iout = 0  # index of the next result to yield
    all_written = False  # inputs exhausted and shutdown sentinels sent
    error_ahead = False  # some received result carries an exception

    def send_sentinels():
        # One (None, None) per worker makes each of them run its cleanup
        # and leave its task loop.
        for _ in procs:
            q_in.put((None, None))
        q_in.close()

    while True:
        if nrun < nprocs and not all_written and not error_ahead:
            args = []
            for it in iterators:
                try:
                    args.append(next(it))
                except StopIteration:
                    # Stop at the first exhausted iterator (like zip() and
                    # the serial branch) instead of consuming from the rest.
                    break

            if len(args) == len(iterators):
                if len(procs) < nrun + 1:
                    # Workers are started lazily, at most nprocs of them.
                    p = multiprocessing.Process(
                        target=worker,
                        args=(q_in, q_out, function, eprintignore, pshared,
                              startup, startup_args, cleanup))
                    p.daemon = True
                    p.start()
                    procs.append(p)

                q_in.put((nwritten, args))
                nwritten += 1
                nrun += 1
            else:
                all_written = True
                send_sentinels()

        # Collect results. While more tasks could still be submitted, only
        # take what is already available; otherwise block until every
        # running task has reported back.
        got_error = False
        try:
            while nrun > 0:
                if nrun < nprocs and not all_written and not error_ahead:
                    i, r, exc = q_out.get_nowait()
                else:
                    while True:
                        try:
                            i, r, exc = q_out.get()
                            break
                        except IOError as e:
                            # Retry reads interrupted by a signal.
                            if e.errno != errno.EINTR:
                                raise

                pending[i] = (r, exc)
                if exc is not None:
                    got_error = True
                nrun -= 1

        except queue.Empty:
            pass

        if got_error:
            # Stop submitting new tasks. Results preceding the failing one
            # are still yielded before its exception is re-raised.
            error_ahead = True

        # Yield, in input order, every result that is ready.
        # NOTE(review): if the consumer stops iterating early, no sentinels
        # are sent and the (daemonic) workers are never joined.
        while iout in pending:
            r, exc = pending.pop(iout)
            if exc is not None:
                if not all_written:
                    send_sentinels()
                raise exc

            yield r
            iout += 1

        if all_written and nrun == 0:
            break

    for p in procs:
        p.join()