Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/parimap.py: 81%

122 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-11 11:01 +0000

1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6''' 

7Parallel :py:func:`map` implementation based on :py:mod:`multiprocessing`. 

8''' 

9 

10try: 

11 import queue 

12except ImportError: 

13 import Queue as queue 

14 

15 

16import logging 

17import multiprocessing 

18import traceback 

19import errno 

20 

21 

# Logger for this module; neither worker() nor parimap() below logs through it.
logger = logging.getLogger('pyrocko.parimap')

23 

24 

def worker(
        q_in, q_out, function, eprintignore, pshared,
        startup, startup_args, cleanup):
    '''
    Job loop executed by each child process started by :py:func:`parimap`.

    Jobs arrive on ``q_in`` as ``(index, args)`` pairs. For each job,
    ``function(*args)`` is called, with ``pshared=pshared`` added as a
    keyword argument when ``pshared`` is not ``None``. The triple
    ``(index, result, exception)`` is then put on ``q_out``:
    ``exception`` is ``None`` on success, and ``result`` is ``None`` when
    ``function`` raised. A pair whose index is ``None`` is the stop signal.

    :param q_in: queue delivering ``(index, args)`` jobs.
    :param q_out: queue receiving ``(index, result, exception)`` triples.
    :param function: callable applied to the arguments of each job.
    :param eprintignore: ``None`` to print the traceback of every exception
        raised by ``function``. Otherwise, an exception class (or tuple of
        classes) whose tracebacks are not printed. Either way, the exception
        object itself is forwarded on ``q_out``.
    :param pshared: extra keyword argument for ``function``, or ``None``.
    :param startup: callable run once before the first job, or ``None``.
    :param startup_args: positional arguments passed to ``startup``.
    :param cleanup: callable run once after the stop signal, or ``None``.
    '''
    extra = {'pshared': pshared} if pshared is not None else {}

    if startup is not None:
        startup(*startup_args)

    while True:
        index, job_args = q_in.get()
        if index is None:
            # Stop signal sent by parimap().
            break

        outcome, failure = None, None
        try:
            outcome = function(*job_args, **extra)
        except Exception as err:
            silenced = (
                eprintignore is not None
                and isinstance(err, eprintignore))
            if not silenced:
                traceback.print_exc()
            failure = err

        # Failures travel back to the parent like ordinary results.
        q_out.put((index, outcome, failure))

    if cleanup is not None:
        cleanup()

52 

53 

def _next_args(iterators):
    '''
    Take one item from each of ``iterators``, in order.

    :returns: list with one item per iterator, or ``None`` as soon as any
        iterator is exhausted. Items already taken from earlier iterators in
        that round are discarded, as :py:func:`zip` does.
    '''
    args = []
    for it in iterators:
        try:
            args.append(next(it))
        except StopIteration:
            return None
    return args


def _send_stop(q_in, nworkers):
    '''
    Put one ``(None, None)`` stop signal per worker on ``q_in``, then close it.

    :py:func:`worker` leaves its job loop (running its ``cleanup``) when it
    receives such a signal. With a bounded queue, ``put`` may block until
    the workers have taken earlier items, i.e. until they finish the jobs
    they are running.
    '''
    for _ in range(nworkers):
        q_in.put((None, None))
    q_in.close()


def parimap(function, *iterables, **kwargs):
    '''
    Parallel, order-preserving version of :py:func:`map`.

    ``function`` is applied to items taken in lock-step from ``iterables``.
    Iteration stops as soon as one iterable is exhausted. Results are
    yielded in input order. Work is spread over up to ``nprocs`` worker
    processes (see :py:func:`worker`), which are started lazily. Items
    travel between processes through :py:class:`multiprocessing.Queue`,
    so arguments, results and exceptions must be picklable.

    Keyword arguments:

    :param nprocs: number of worker processes. ``None`` (default) uses
        :py:func:`multiprocessing.cpu_count`. ``1`` runs everything in the
        calling process without starting any workers.
    :param eprintignore: exception class (or tuple of classes) whose
        tracebacks workers should not print. The default ``'all'`` is
        translated to ``None``, in which case every traceback is printed.
    :param pshared: if not ``None``, passed to every ``function`` call as
        keyword argument ``pshared``.
    :param startup: callable run once per worker process before its first
        job, or once in the calling process when ``nprocs == 1``.
    :param startup_args: tuple of positional arguments for ``startup``.
    :param cleanup: callable without arguments, run once by each worker when
        it is stopped, or once in the calling process when ``nprocs == 1``
        and iteration ends.

    :raises TypeError: on unknown keyword arguments.
    :raises ValueError: if ``nprocs`` is smaller than 1.
    :raises: the first exception (in input order) raised by ``function``.
        All results preceding it are yielded first.

    Since this is a generator, the checks above run on the first ``next()``
    call. If the consumer stops iterating early, running workers are asked
    to quit; closing the generator waits for jobs they are still executing.
    '''
    allowed = (
        'nprocs', 'eprintignore', 'pshared', 'startup', 'startup_args',
        'cleanup')
    unknown = sorted(k for k in kwargs if k not in allowed)
    if unknown:
        # Raised rather than asserted so the check survives ``python -O``.
        raise TypeError(
            'parimap() got unexpected keyword argument(s): %s'
            % ', '.join(unknown))

    nprocs = kwargs.get('nprocs', None)
    eprintignore = kwargs.get('eprintignore', 'all')
    pshared = kwargs.get('pshared', None)
    startup = kwargs.get('startup', None)
    startup_args = kwargs.get('startup_args', ())
    cleanup = kwargs.get('cleanup', None)

    if nprocs is not None and nprocs < 1:
        # Without a worker slot no job could ever be dispatched, and the
        # scheduling loop below would spin forever.
        raise ValueError('nprocs must be at least 1, got %r' % (nprocs,))

    if eprintignore == 'all':
        eprintignore = None

    iterators = [iter(it) for it in iterables]

    if nprocs == 1:
        # Serial fallback: everything runs in the calling process.
        call_kwargs = {} if pshared is None else {'pshared': pshared}
        if startup is not None:
            startup(*startup_args)

        try:
            while True:
                args = _next_args(iterators)
                if args is None:
                    return
                yield function(*args, **call_kwargs)
        finally:
            if cleanup is not None:
                cleanup()

        return

    if nprocs is None:
        nprocs = multiprocessing.cpu_count()

    # Capacity 1: jobs are produced only as fast as workers take them.
    q_in = multiprocessing.Queue(1)
    q_out = multiprocessing.Queue()

    procs = []           # worker processes started so far
    pending = {}         # index -> (result, exception), not yet yielded
    nrun = 0             # jobs handed out whose results have not arrived
    nwritten = 0         # jobs handed out so far, i.e. the next job index
    iout = 0             # index of the next result to yield
    all_written = False  # inputs exhausted and stop signals sent
    error_ahead = False  # some received result carries an exception

    try:
        while True:
            # 1. Hand out at most one new job if a worker slot is free.
            if nrun < nprocs and not all_written and not error_ahead:
                args = _next_args(iterators)
                if args is not None:
                    # Start another worker only if there are no more
                    # workers than outstanding jobs (at most nprocs).
                    if len(procs) < nrun + 1:
                        p = multiprocessing.Process(
                            target=worker,
                            args=(q_in, q_out, function, eprintignore,
                                  pshared, startup, startup_args, cleanup))
                        p.daemon = True
                        p.start()
                        procs.append(p)

                    q_in.put((nwritten, args))
                    nwritten += 1
                    nrun += 1
                else:
                    all_written = True
                    _send_stop(q_in, len(procs))

            # 2. Collect finished jobs: poll while more jobs could still be
            #    handed out, otherwise block until a result arrives.
            try:
                while nrun > 0:
                    if nrun < nprocs and not all_written and not error_ahead:
                        i, r, exc = q_out.get_nowait()
                    else:
                        while True:
                            try:
                                i, r, exc = q_out.get()
                                break
                            except IOError as e:
                                # Retry interrupted system calls (needed
                                # before Python 3.5, cf. PEP 475).
                                if e.errno != errno.EINTR:
                                    raise

                    pending[i] = (r, exc)
                    nrun -= 1

            except queue.Empty:
                pass

            # Stop handing out jobs once a failure is known to be coming.
            if any(exc is not None for (_, exc) in pending.values()):
                error_ahead = True

            # 3. Yield results in input order as far as they are available.
            while iout in pending:
                r, exc = pending.pop(iout)
                if exc is not None:
                    if not all_written:
                        _send_stop(q_in, len(procs))
                    raise exc

                yield r
                iout += 1

            if all_written and nrun == 0:
                break

    except GeneratorExit:
        # The consumer stopped early (``break``, ``close()``, garbage
        # collection). Ask live workers to quit instead of leaving them
        # blocked on q_in forever. Dead workers are skipped so the bounded
        # q_in cannot block this forever (e.g. at interpreter shutdown).
        if not all_written:
            _send_stop(q_in, sum(1 for p in procs if p.is_alive()))
        raise

    for p in procs:
        p.join()