1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6try: 

7 import queue 

8except ImportError: 

9 import Queue as queue 

10 

11 

12import logging 

13import platform 

14import multiprocessing 

15import traceback 

16import errno 

17 

18 

# Module-level logger; parimap() uses it to warn when it has to fall back
# to serial execution.
logger = logging.getLogger(name='pyrocko.parimap')

20 

21 

def _picklable_exception(exc):
    '''
    Return *exc*, or a :py:exc:`RuntimeError` stand-in if it cannot be sent.

    Task outcomes travel back to the parent through a
    :py:class:`multiprocessing.Queue`, which pickles them. An exception that
    fails to pickle is dropped by the queue (leaving the parent waiting for
    a result that never arrives), and one that fails to unpickle breaks the
    parent's ``get``. Both cases are replaced by a :py:exc:`RuntimeError`
    carrying the original type name and message.
    '''
    import pickle

    try:
        pickle.loads(pickle.dumps(exc))
    except Exception:
        return RuntimeError(
            'parimap: worker raised unpicklable %s: %s'
            % (type(exc).__name__, exc))

    return exc


def worker(
        q_in, q_out, function, eprintignore, pshared,
        startup, startup_args, cleanup):
    '''
    Main loop of a :py:func:`parimap` worker process.

    Runs ``startup(*startup_args)`` once (if *startup* is given). It then
    repeatedly takes ``(index, args)`` tasks from *q_in* and calls
    ``function(*args)``, adding ``pshared=pshared`` when *pshared* is not
    ``None``. For each task it puts ``(index, result, exception)`` on
    *q_out*: *result* is ``None`` if the call raised, and *exception* is
    ``None`` if it did not.

    An item whose index is ``None`` ends the loop. *cleanup* (if given) is
    called without arguments before returning.

    Tracebacks of exceptions raised by *function* are printed, unless the
    exception is an instance of *eprintignore*. ``None`` means print all.

    If *startup* raises, its traceback is printed and *cleanup* is skipped.
    Every subsequent task is answered with the startup exception, so the
    parent re-raises it instead of waiting indefinitely on a dead worker.
    '''
    kwargs = {}
    if pshared is not None:
        kwargs['pshared'] = pshared

    startup_error = None
    if startup is not None:
        try:
            startup(*startup_args)
        except Exception as e:
            traceback.print_exc()
            startup_error = _picklable_exception(e)

    while True:
        i, args = q_in.get()
        if i is None:
            # Shutdown sentinel.
            if cleanup is not None and startup_error is None:
                cleanup()

            break

        if startup_error is not None:
            # Worker is unusable; report the startup failure for this task.
            q_out.put((i, None, startup_error))
            continue

        res, exception = None, None
        try:
            res = function(*args, **kwargs)
        except Exception as e:
            if eprintignore is None or not isinstance(e, eprintignore):
                traceback.print_exc()

            exception = _picklable_exception(e)

        # NOTE(review): an unpicklable *res* would still be dropped by the
        # queue; checking it here would mean pickling every result twice.
        q_out.put((i, res, exception))

49 

50 

def _stop_workers(q_in, procs):
    '''
    Send one shutdown sentinel per live worker process and close *q_in*.

    A worker leaves its loop after receiving one ``(None, None)`` item. Only
    processes that are still alive are counted. A dead process never
    consumes its sentinel, and with the size-1 input queue used by
    :py:func:`parimap` a surplus put would block forever. This can happen,
    for example, when the generator is finalized at interpreter exit, after
    multiprocessing has terminated its daemonic children.
    '''
    nalive = sum(1 for p in procs if p.is_alive())
    for _ in range(nalive):
        q_in.put((None, None))

    q_in.close()


def parimap(function, *iterables, **kwargs):
    '''
    Parallel, order-preserving version of :py:func:`map`.

    Calls ``function(*args)`` for each tuple of arguments taken in lockstep
    from *iterables*, stopping at the shortest one. Results are yielded in
    input order. With more than one process, the calls are run in worker
    processes. Arguments, results and exceptions then pass through
    :py:class:`multiprocessing.Queue` objects and must be picklable.

    Supported keyword arguments:

    :param nprocs: number of worker processes. ``None`` (default) uses
        :py:func:`multiprocessing.cpu_count`. ``1`` runs everything serially
        in the calling process. Forced to ``1`` on Windows.
    :param eprintignore: exception class or tuple of classes whose
        tracebacks the workers do not print. The default ``'all'`` (or
        ``None``) prints tracebacks of all exceptions. Not used in serial
        mode.
    :param pshared: if not ``None``, passed to every call as the keyword
        argument ``pshared``.
    :param startup: callable invoked as ``startup(*startup_args)`` once per
        worker process (once in total in serial mode), before any item is
        processed.
    :param startup_args: tuple of arguments for *startup*, default ``()``.
    :param cleanup: callable invoked without arguments by each worker
        process when it is shut down. In serial mode it is invoked when the
        generator finishes, is closed, or *function* raises.

    :raises TypeError: if an unsupported keyword argument is given.
    :raises: the first exception (in input order) raised by *function*,
        after all preceding results have been yielded.
    '''
    allowed = (
        'nprocs', 'eprintignore', 'pshared', 'startup', 'startup_args',
        'cleanup')

    unexpected = sorted(k for k in kwargs if k not in allowed)
    if unexpected:
        raise TypeError(
            'parimap() got unexpected keyword argument(s): %s'
            % ', '.join(unexpected))

    nprocs = kwargs.get('nprocs', None)
    eprintignore = kwargs.get('eprintignore', 'all')
    pshared = kwargs.get('pshared', None)
    startup = kwargs.get('startup', None)
    startup_args = kwargs.get('startup_args', ())
    cleanup = kwargs.get('cleanup', None)

    if eprintignore == 'all':
        eprintignore = None

    if platform.system() == 'Windows' and nprocs != 1:
        logger.warning(
            'The parimap module relies on fork() for parallelism. This does '
            'not work on Windows. Using serial code.')
        nprocs = 1

    if nprocs == 1:
        # Serial mode: plain lazy map in the calling process.
        iterables = list(map(iter, iterables))
        call_kwargs = {}
        if pshared is not None:
            call_kwargs['pshared'] = pshared

        if startup is not None:
            startup(*startup_args)

        try:
            while True:
                args = []
                for it in iterables:
                    try:
                        args.append(next(it))
                    except StopIteration:
                        return

                yield function(*args, **call_kwargs)

        finally:
            if cleanup is not None:
                cleanup()

        return

    if nprocs is None:
        nprocs = multiprocessing.cpu_count()

    # maxsize 1: put() blocks until a worker has taken the previous task.
    q_in = multiprocessing.Queue(1)
    q_out = multiprocessing.Queue()

    procs = []
    results = []        # (index, result, exception) received, not yielded
    nrun = 0            # tasks submitted whose outcome is not received yet
    nwritten = 0        # tasks submitted so far (= index of the next task)
    iout = 0            # index of the next outcome to yield / raise
    all_written = False
    error_ahead = False
    stopped = False     # shutdown sentinels have been sent
    iterables = list(map(iter, iterables))

    try:
        while True:
            # Submit one more task if a slot is free.
            if nrun < nprocs and not all_written and not error_ahead:
                args = []
                for it in iterables:
                    try:
                        args.append(next(it))
                    except StopIteration:
                        pass

                if len(args) == len(iterables):
                    if len(procs) < nrun + 1:
                        p = multiprocessing.Process(
                            target=worker,
                            args=(q_in, q_out, function, eprintignore,
                                  pshared, startup, startup_args, cleanup))
                        p.daemon = True
                        p.start()
                        procs.append(p)

                    q_in.put((nwritten, args))
                    nwritten += 1
                    nrun += 1
                else:
                    all_written = True
                    stopped = True
                    _stop_workers(q_in, procs)

            # Collect finished tasks. Block only while no further task can
            # be submitted; otherwise just drain what is available.
            try:
                while nrun > 0:
                    if nrun < nprocs and not all_written and not error_ahead:
                        results.append(q_out.get_nowait())
                    else:
                        while True:
                            try:
                                results.append(q_out.get())
                                break
                            except IOError as e:
                                # Retry when interrupted by a signal.
                                if e.errno != errno.EINTR:
                                    raise

                    nrun -= 1

            except queue.Empty:
                pass

            if results:
                # Indices are unique, so this orders by index only.
                results.sort()
                # An error anywhere ahead stops further submission.
                if any(exc is not None for (_, _, exc) in results):
                    error_ahead = True

                # Yield (or raise) everything that is next in input order.
                while results and results[0][0] == iout:
                    _, r, exc = results.pop(0)
                    if exc is not None:
                        raise exc

                    yield r
                    iout += 1

            if all_written and nrun == 0:
                break

    finally:
        # Runs on normal completion, on a re-raised worker exception, and
        # when the consumer closes the generator early (e.g. ``break`` out
        # of a for loop). In the latter two cases the workers have not been
        # told to stop yet and would otherwise block on q_in forever.
        if not stopped:
            _stop_workers(q_in, procs)

    for p in procs:
        p.join()