Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/gui/snuffler/snufflings/deep_picker.py: 29%

182 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-09-24 11:43 +0000

1# https://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6from typing import Any 

7import logging 

8from pyrocko.trace import Trace, NoData 

9from pyrocko import obspy_compat as compat 

10 

11from ..snuffling import Param, Snuffling, Choice, Switch 

12from ..marker import PhaseMarker 

13 

14try: 

15 import seisbench.models as sbm 

16 import torch 

17 from obspy import Stream 

18except ImportError: 

19 sbm = None 

20 Stream = None 

21 

22 

# One hour in seconds (currently unused in this module).
h = 3600.0
logger = logging.getLogger(__name__)


# Names of the pretrained weight sets offered in the "Training model" choice.
# The last eight entries (``detectionmethods[-8:]``, 'cascadia' ...
# 'san_andreas') are the ones the snuffling treats as compatible with the
# LFEDetect network; the others are used with the remaining networks.
detectionmethods = (
    'original', 'ethz', 'instance', 'scedc', 'stead', 'geofon',
    'neic', 'cascadia', 'cms', 'jcms', 'jcs', 'jms', 'mexico',
    'nankai', 'san_andreas')

# SeisBench network classes offered in the "Network" choice; each name is
# resolved as an attribute of ``seisbench.models``.
networks = (
    'PhaseNet', 'EQTransformer', 'GPD', 'LFEDetect'
)


# Loaded (and possibly CUDA-moved / torch-compiled) models, keyed by
# ``(network, training_model)`` tuples as written by ``get_model``.
MODEL_CACHE: dict[tuple[str, str], Any] = {}

38 

39 

def get_blinding_samples(model: "sbm.WaveformModel") -> tuple[int, int]:
    """Return the blinding setting of a SeisBench model.

    The value is read from ``model.default_args["blinding"]``. When that key
    is absent, the second element of ``model._annotate_args["blinding"]`` is
    used instead (presumably a ``(description, default)`` pair -- TODO
    confirm against the supported SeisBench versions).

    A ``KeyError`` propagates if neither location provides ``"blinding"``.
    """
    try:
        blinding = model.default_args["blinding"]
    except KeyError:
        blinding = model._annotate_args["blinding"][1]
    return blinding

45 

46 

class DeepDetector(Snuffling):
    """Snuffling that picks P and S phases with pretrained SeisBench models.

    The selected traces are chopped into windows (``chopper_selected_traces``
    in ``'visible'`` mode), optionally filtered with the viewer's
    high-/lowpass settings, converted to an ObsPy ``Stream`` and passed to the
    chosen SeisBench network. Picks are added as ``PhaseMarker`` objects;
    optionally the model's annotation traces are added to the viewer too.
    """

    def __init__(self, *args, **kwargs):
        Snuffling.__init__(self, *args, **kwargs)
        # Initial value; get_default_threshold() reads it, and setup() calls
        # that while building the threshold parameters.
        self.training_model = 'original'
        # Last (training model, network) selection seen by
        # adjust_thresholds(); used to detect when the user changes them.
        self.old_method: str | None = None
        self.old_network: str | None = None
        try:
            from seisbench.util.annotations import PickList
        except ImportError:
            self.pick_list = None
        else:
            self.pick_list = PickList()

    def help(self) -> str:
        return '''
<html>
<head>
<style type="text/css">
    body { margin-left:10px }
</style>
</head>
<body>
<h1 align="center">SeisBench: ML Picker</h1>
<p>
    Automatic detection of P- and S-Phases in the given traces, using various
    pre-trained ML models from SeisBench.<br/>
</p>
<p>
<b>Parameters:</b><br />
    <b>&middot; P threshold</b>
    - Define a trigger threshold for the P-Phase detection <br />
    <b>&middot; S threshold</b>
    - Define a trigger threshold for the S-Phase detection <br />
    <b>&middot; Network / Training model</b>
    - Choose the SeisBench network and the pretrained model used for
    detection. <br />
</p>
<p>
    <span style="color:red">P-Phases</span> are marked with red markers, <span
    style="color:green">S-Phases</span> with green markers.
</p>
<p>
    More information about SeisBench can be found <a
    href="https://seisbench.readthedocs.io/en/stable/index.html">on the
    SeisBench website</a>.
</p>
</body>
</html>
        '''

    def setup(self) -> None:
        """Declare the snuffling's name and GUI parameters."""
        self.set_name('SeisBench: ML Picker')
        self.add_parameter(
            Choice(
                'Network',
                'network',
                default='PhaseNet',
                choices=networks,
            )
        )

        self.add_parameter(
            Choice(
                'Training model',
                'training_model',
                default='original',
                choices=detectionmethods,
            )
        )

        self.add_parameter(
            Param(
                'P threshold',
                'p_threshold',
                default=self.get_default_threshold('P'),
                minimum=0.0,
                maximum=1.0,
            )
        )
        self.add_parameter(
            Param(
                'S threshold', 's_threshold',
                default=self.get_default_threshold('S'),
                minimum=0.0,
                maximum=1.0,
            )
        )

        self.add_parameter(
            Switch(
                'Show annotation traces', 'show_annotation_traces',
                default=False
            )
        )
        self.add_parameter(
            Switch(
                'Use predefined filters', 'use_predefined_filters',
                default=True
            )
        )
        self.add_parameter(
            Param(
                'Rescaling factor', 'scale_factor',
                default=1.0, minimum=0.1, maximum=10.0
            )
        )

        self.set_live_update(True)

    def get_default_threshold(self, phase: str) -> float:
        """Return the default trigger threshold for ``phase`` ('P' or 'S').

        Returns 0.3 when SeisBench is unavailable or the 'original' weights
        are selected; otherwise the ``P_threshold``/``S_threshold`` entry of
        the loaded model's ``default_args``.

        :raises ValueError: if ``phase`` is neither 'P' nor 'S'.
        """
        if phase not in ('P', 'S'):
            raise ValueError(f'Unknown phase {phase!r}, expected "P" or "S".')

        if sbm is None or self.training_model == 'original':
            return 0.3

        model = self.get_model(self.network, self.training_model)
        return model.default_args[f'{phase}_threshold']

    def get_model(self, network: str, model: str) -> Any:
        """Return the SeisBench ``network`` with pretrained weights ``model``.

        Models are cached in ``MODEL_CACHE`` under ``(network, model)``.
        A freshly loaded model is moved to CUDA when possible, switched to
        eval mode, and passed through ``torch.compile`` when that succeeds.

        :raises ImportError: if SeisBench is not installed.
        """
        if sbm is None:
            raise ImportError(
                'SeisBench is not installed. Install to use this plugin.')

        key = (network, model)
        # The cache is keyed by the (network, model) tuple; checking only
        # ``model`` would never hit and reload the weights on every call.
        if key in MODEL_CACHE:
            return MODEL_CACHE[key]

        # Resolve the network class by attribute instead of eval()-ing a
        # string assembled from parameter values.
        seisbench_model = getattr(sbm, network).from_pretrained(model)
        try:
            seisbench_model = seisbench_model.to('cuda')
        except (RuntimeError, AssertionError):
            logger.info('CUDA not available, using CPU')
        seisbench_model.eval()
        try:
            seisbench_model = torch.compile(
                seisbench_model, mode='max-autotune')
        except RuntimeError:
            logger.info('Torch compile failed')
        MODEL_CACHE[key] = seisbench_model
        return seisbench_model

    def set_default_thresholds(self) -> None:
        """Reset the P/S threshold parameters to the current model's defaults.

        'original' and the LFEDetect weight sets (``detectionmethods[-8:]``)
        use 0.3 for both; all other weights use get_default_threshold().
        """
        if self.training_model == 'original' \
                or self.training_model in detectionmethods[-8:]:
            p_threshold = s_threshold = 0.3
        else:
            p_threshold = self.get_default_threshold('P')
            s_threshold = self.get_default_threshold('S')
        self.set_parameter('p_threshold', p_threshold)
        self.set_parameter('s_threshold', s_threshold)

    def content_changed(self) -> None:
        """Re-run the detection when live update is enabled."""
        if self._live_update:
            self.call()

    def panel_visibility_changed(self, visible: bool) -> None:
        """Follow viewer changes only while the panel is visible."""
        if visible:
            self._connect_signals()
        else:
            self._disconnect_signals()

    def _connect_signals(self) -> None:
        """Connect pile and filter changes of the viewer to content_changed."""
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.connect(self.content_changed)
        viewer.frequency_filter_changed.connect(self.content_changed)

    def _disconnect_signals(self) -> None:
        """Undo _connect_signals()."""
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.disconnect(self.content_changed)
        viewer.frequency_filter_changed.disconnect(self.content_changed)

    def call(self) -> None:
        """Run the selected model over the traces and add markers/traces."""
        self.cleanup()
        self.adjust_thresholds()
        model = self.get_model(self.network, self.training_model)

        tinc = 300.0
        tpad = 1.0

        # Extra padding of two periods of the highpass corner, cut off again
        # in apply_filter() (presumably to drop filter transients).
        tpad_filter = 0.0
        if self.use_predefined_filters:
            fmin = self.get_viewer().highpass
            tpad_filter = 0.0 if fmin is None else 2.0/fmin

        for batch in self.chopper_selected_traces(
            tinc=tinc,
            tpad=tpad + tpad_filter,
            fallback=True,
            mode='visible',
            progress='Calculating SeisBench detections...',
            responsive=True,
            style='batch',
        ):
            traces = batch.traces

            if not traces:
                continue

            wmin, wmax = batch.tmin, batch.tmax

            for tr in traces:
                tr.meta = {'tabu': True}

            if self.use_predefined_filters:
                traces = [self.apply_filter(tr, tpad_filter) for tr in traces]

            stream = Stream([compat.to_obspy_trace(tr) for tr in traces])

            # With rescaling, the declared sampling rate is divided by the
            # factor; remember each station's start time so pick times can
            # be mapped back in _picks_to_markers().
            tr_starttimes: dict[str, float] = {}
            if self.scale_factor != 1:
                for tr in stream:
                    tr.stats.sampling_rate /= self.scale_factor
                    s = tr.stats
                    tr_nsl = '.'.join((s.network, s.station, s.location))
                    tr_starttimes[tr_nsl] = s.starttime.timestamp

            output_classify = model.classify(
                stream,
                P_threshold=self.p_threshold,
                S_threshold=self.s_threshold,
            )

            if self.show_annotation_traces:
                self._add_annotation_traces(model, stream, wmin, wmax)

            self.pick_list = output_classify.picks

            self.add_markers(self._picks_to_markers(
                output_classify.picks, tr_starttimes, wmin, wmax))

    def _add_annotation_traces(self, model, stream, wmin, wmax) -> None:
        """Add ``model.annotate`` output for ``stream``, chopped to the window.

        Output channels ending in 'N' are skipped (presumably the noise
        probability -- TODO confirm for every network).
        """
        output_annotation = model.annotate(
            stream,
            P_threshold=self.p_threshold,
            S_threshold=self.s_threshold,
        )

        if self.scale_factor != 1:
            # Loop-invariant: depends only on the model and scale factor.
            # The 100.0 assumes a 100 Hz model sampling rate -- TODO confirm
            # for all networks.
            blinding_samples = max(get_blinding_samples(model))
            blinding_seconds = (blinding_samples / 100.0) * \
                (1.0 - 1 / self.scale_factor)
            for tr in output_annotation:
                tr.stats.sampling_rate *= self.scale_factor
                tr.stats.starttime -= blinding_seconds

        ano_traces = []
        for tr in compat.to_pyrocko_traces(output_annotation):
            if tr.channel.endswith('N'):
                continue
            tr = tr.copy()
            try:
                tr.chop(wmin, wmax)
            except NoData:
                # Annotation does not overlap this window; nothing to show.
                continue
            tr.meta = {'tabu': True, 'annotation': True}
            ano_traces.append(tr)

        # Adding traces changes the pile, which would re-trigger call() via
        # content_changed(); detach the signals meanwhile and always
        # reattach them, even if add_traces() fails.
        self._disconnect_signals()
        try:
            self.add_traces(ano_traces)
        finally:
            self._connect_signals()

    def _picks_to_markers(
            self, picks, tr_starttimes, wmin, wmax) -> 'list[PhaseMarker]':
        """Convert SeisBench picks in ``[wmin, wmax)`` to ``PhaseMarker``s.

        Picks outside the batch window are dropped because the chopped
        traces are padded, so neighbouring windows would duplicate them.
        """
        markers = []
        for pick in picks:
            tpeak = pick.peak_time.timestamp
            if self.scale_factor != 1:
                # Undo the sampling-rate rescaling relative to trace start.
                tr_starttime = tr_starttimes[pick.trace_id]
                tpeak = tr_starttime + \
                    (tpeak - tr_starttime) / self.scale_factor

            if not wmin <= tpeak < wmax:
                continue

            codes = tuple(pick.trace_id.split('.')) + ('*',)
            markers.append(PhaseMarker(
                [codes],
                tmin=tpeak,
                tmax=tpeak,
                kind=0 if pick.phase == 'P' else 1,
                phasename=pick.phase,
                # The pick's peak value is stored in the incidence_angle
                # field (presumably to keep the model confidence with the
                # marker).
                incidence_angle=pick.peak_value,
            ))
        return markers

    def adjust_thresholds(self) -> None:
        """Validate the network/model pairing and reset thresholds on change.

        LFEDetect only accepts the last eight weight sets; the other networks
        only the rest. An incompatible choice is replaced ('cascadia' or
        'original') and the thresholds are reset to that model's defaults.
        """
        method = self.get_parameter_value('training_model')
        network = self.get_parameter_value('network')
        if method == self.old_method and network == self.old_network:
            return

        lfe_methods = detectionmethods[-8:]
        if network == 'LFEDetect' and method not in lfe_methods:
            logger.info(
                'The selected model is not compatible with LFEDetect '
                'please select a model from the last 8 models in the '
                'list. Default is cascadia.')
            method = 'cascadia'
            self.set_parameter('training_model', method)
        elif network != 'LFEDetect' and method in lfe_methods:
            logger.info(
                'The selected model is not compatible with the selected '
                'network. Default is original.')
            method = 'original'
            self.set_parameter('training_model', method)

        self.set_default_thresholds()
        # Store the *corrected* selection; storing the pre-correction value
        # would make the next call() see a change and reset user thresholds.
        self.old_method = method
        self.old_network = network

    def apply_filter(self, tr: Trace, tcut: float) -> Trace:
        """Apply the viewer's low-/highpass to ``tr`` in place.

        Afterwards ``tcut`` seconds are chopped from both ends; if that
        leaves no data, the trace is returned unchopped.
        """
        viewer = self.get_viewer()
        if viewer.lowpass is not None:
            tr.lowpass(4, viewer.lowpass, nyquist_exception=False)
        if viewer.highpass is not None:
            tr.highpass(4, viewer.highpass, nyquist_exception=False)
        try:
            tr.chop(tr.tmin + tcut, tr.tmax - tcut)
        except NoData:
            pass
        return tr

356 

357 

def __snufflings__():
    """Return the snufflings provided by this module.

    The list holds a single new ``DeepDetector`` instance.
    """
    provided = [DeepDetector()]
    return provided