1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6try: 

7 import queue 

8except ImportError: 

9 import Queue as queue 

10 

11 

12import logging 

13import multiprocessing 

14import traceback 

15import errno 

16 

17 

18logger = logging.getLogger('pyrocko.parimap') 

19 

20 

21def worker( 

22 q_in, q_out, function, eprintignore, pshared, 

23 startup, startup_args, cleanup): 

24 

25 kwargs = {} 

26 if pshared is not None: 

27 kwargs['pshared'] = pshared 

28 

29 if startup is not None: 

30 startup(*startup_args) 

31 

32 while True: 

33 i, args = q_in.get() 

34 if i is None: 

35 if cleanup is not None: 

36 cleanup() 

37 

38 break 

39 

40 res, exception = None, None 

41 try: 

42 res = function(*args, **kwargs) 

43 except Exception as e: 

44 if eprintignore is None or not isinstance(e, eprintignore): 

45 traceback.print_exc() 

46 exception = e 

47 q_out.put((i, res, exception)) 

48 

49 

50def parimap(function, *iterables, **kwargs): 

51 assert all( 

52 k in ( 

53 'nprocs', 'eprintignore', 'pshared', 'startup', 'startup_args', 

54 'cleanup') 

55 for k in kwargs.keys()) 

56 

57 nprocs = kwargs.get('nprocs', None) 

58 eprintignore = kwargs.get('eprintignore', 'all') 

59 pshared = kwargs.get('pshared', None) 

60 startup = kwargs.get('startup', None) 

61 startup_args = kwargs.get('startup_args', ()) 

62 cleanup = kwargs.get('cleanup', None) 

63 

64 if eprintignore == 'all': 

65 eprintignore = None 

66 

67 if nprocs == 1: 

68 iterables = list(map(iter, iterables)) 

69 kwargs = {} 

70 if pshared is not None: 

71 kwargs['pshared'] = pshared 

72 

73 if startup is not None: 

74 startup(*startup_args) 

75 

76 try: 

77 while True: 

78 args = [] 

79 for it in iterables: 

80 try: 

81 args.append(next(it)) 

82 except StopIteration: 

83 return 

84 

85 yield function(*args, **kwargs) 

86 

87 finally: 

88 if cleanup is not None: 

89 cleanup() 

90 

91 return 

92 

93 if nprocs is None: 

94 nprocs = multiprocessing.cpu_count() 

95 

96 q_in = multiprocessing.Queue(1) 

97 q_out = multiprocessing.Queue() 

98 

99 procs = [] 

100 

101 results = [] 

102 nrun = 0 

103 nwritten = 0 

104 iout = 0 

105 all_written = False 

106 error_ahead = False 

107 iterables = list(map(iter, iterables)) 

108 while True: 

109 if nrun < nprocs and not all_written and not error_ahead: 

110 args = [] 

111 for it in iterables: 

112 try: 

113 args.append(next(it)) 

114 except StopIteration: 

115 pass 

116 

117 if len(args) == len(iterables): 

118 if len(procs) < nrun + 1: 

119 p = multiprocessing.Process( 

120 target=worker, 

121 args=(q_in, q_out, function, eprintignore, pshared, 

122 startup, startup_args, cleanup)) 

123 p.daemon = True 

124 p.start() 

125 procs.append(p) 

126 

127 q_in.put((nwritten, args)) 

128 nwritten += 1 

129 nrun += 1 

130 else: 

131 all_written = True 

132 [q_in.put((None, None)) for p in procs] 

133 q_in.close() 

134 

135 try: 

136 while nrun > 0: 

137 if nrun < nprocs and not all_written and not error_ahead: 

138 results.append(q_out.get_nowait()) 

139 else: 

140 while True: 

141 try: 

142 results.append(q_out.get()) 

143 break 

144 except IOError as e: 

145 if e.errno != errno.EINTR: 

146 raise 

147 

148 nrun -= 1 

149 

150 except queue.Empty: 

151 pass 

152 

153 if results: 

154 results.sort() 

155 # check for error ahead to prevent further enqueuing 

156 if any(exc for (_, _, exc) in results): 

157 error_ahead = True 

158 

159 while results: 

160 (i, r, exc) = results[0] 

161 if i == iout: 

162 results.pop(0) 

163 if exc is not None: 

164 if not all_written: 

165 [q_in.put((None, None)) for p in procs] 

166 q_in.close() 

167 raise exc 

168 else: 

169 yield r 

170 

171 iout += 1 

172 else: 

173 break 

174 

175 if all_written and nrun == 0: 

176 break 

177 

178 [p.join() for p in procs] 

179 return