Coverage for /usr/local/lib/python3.13/dist-packages/pyrocko/gui/snuffler/snufflings/seisbench_picker.py: 30%

197 statements  

« prev     ^ index     » next       coverage.py v7.6.0, created at 2025-12-04 10:41 +0000

1# https://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6from typing import Dict, Tuple 

7import logging 

8from os.path import commonprefix 

9from pyrocko.trace import Trace, NoData 

10from pyrocko import obspy_compat as compat 

11 

12from ..snuffling import Param, Snuffling, Choice, Switch 

13from ..marker import PhaseMarker 

14 

15try: 

16 import seisbench.models as sbm 

17 from seisbench.models import WaveformModel 

18 import torch 

19 from obspy import Stream 

20except ImportError: 

21 sbm = None 

22 Stream = None 

23 

24 

# One hour expressed in seconds.
HOUR = 3600.0

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names offered by the 'Network' choice; they are looked up as attributes
# of ``seisbench.models`` when a model is loaded.
ML_NETWORKS = ('PhaseNet', 'EQTransformer', 'GPD', 'LFEDetect')

# Names offered by the 'Training model' choice. The ordering matters: the
# detector treats the last eight entries (``TRAINED_MODELS[-8:]``) as the
# weights that belong to the 'LFEDetect' network.
TRAINED_MODELS = (
    'original',
    'ethz',
    'instance',
    'scedc',
    'stead',
    'geofon',
    'neic',
    'cascadia',
    'cms',
    'jcms',
    'jcs',
    'jms',
    'mexico',
    'nankai',
    'san_andreas',
)

# Loaded models, keyed by (network name, training model name), so that
# every combination is instantiated only once per process.
MODEL_CACHE: Dict[Tuple[str, str], "WaveformModel"] = {}

38 

39 

def get_blinding_samples(model: "WaveformModel") -> Tuple[int, int]:
    """Return the blinding setting of *model*.

    The value stored under ``"blinding"`` in ``model.default_args`` is
    preferred. When that key is absent, element ``[1]`` of the
    ``"blinding"`` entry in ``model._annotate_args`` is returned instead.

    Raises:
        KeyError: if neither mapping has a ``"blinding"`` entry.
    """
    defaults = model.default_args
    if "blinding" in defaults:
        return defaults["blinding"]

    # NOTE(review): element [1] is presumably the default value of a
    # (description, default) pair - confirm against SeisBench.
    annotate_entry = model._annotate_args["blinding"]
    return annotate_entry[1]

45 

46 

class SeisBenchDetector(Snuffling):
    """Snuffling that marks P and S phases using pre-trained SeisBench models.

    For every batch of selected traces the chosen model's ``classify`` is
    run and one :class:`PhaseMarker` is added per pick inside the batch
    window. Optionally the model's annotation traces are added to the
    viewer as well.
    """

    network: str
    training_model: str

    p_threshold: float
    s_threshold: float
    scale_factor: float

    show_annotation_traces: bool
    use_predefined_filters: bool

    # Threshold used when no model-specific default is looked up.
    _FALLBACK_THRESHOLD = 0.3
    # Training models that are paired with the 'LFEDetect' network: by
    # convention of this module, the last eight TRAINED_MODELS entries.
    _LFE_MODELS = TRAINED_MODELS[-8:]

    def __init__(self, *args, **kwargs):
        """Initialise the snuffling and its change-tracking state."""
        Snuffling.__init__(self, *args, **kwargs)
        self.training_model = 'original'

        # Last (training model, network) pair seen by adjust_thresholds().
        self.old_method: str = ""
        self.old_network: str = ""

        try:
            from seisbench.util.annotations import PickList
            self.pick_list: PickList | None = PickList()
        except ImportError:
            # SeisBench is optional; without it there is no pick list.
            self.pick_list = None

    def help(self) -> str:
        """Return the HTML help text for this snuffling."""
        return '''
<html>
<head>
<style type="text/css">
    body { margin-left:10px };
</style>
</head>
<body>
<h1 align="center">SeisBench: ML Picker</h1>
<p>
    Automatic detection of P- and S-Phases in the given traces, using various
    pre-trained ML models from SeisBench.<br/>
</p>
<p>
<b>Parameters:</b><br />
    <b>&middot; P threshold</b>
    - Define a trigger threshold for the P-Phase detection <br />
    <b>&middot; S threshold</b>
    - Define a trigger threshold for the S-Phase detection <br />
    <b>&middot; Detection method</b>
    - Choose the pretrained model, used for detection. <br />
    <b>&middot; Rescaling</b>
    - Factor to stretch and compress the waveforms before inference <br />
</p>
<p>
    <span style="color:red">P-Phases</span> are marked with red markers, <span
    style="color:green">S-Phases</span> with green markers.
</p>
<p>
    More information about SeisBench can be found <a
    href="https://seisbench.readthedocs.io/en/stable/index.html">on the
    SeisBench website</a>.
</p>
</body>
</html>'''

    def setup(self) -> None:
        """Name the snuffling, declare its parameters, enable live update."""
        self.set_name('SeisBench: ML Picker')
        self.add_parameter(
            Choice(
                'Network',
                'network',
                default='PhaseNet',
                choices=ML_NETWORKS,
            )
        )

        self.add_parameter(
            Choice(
                'Training model',
                'training_model',
                default='original',
                choices=TRAINED_MODELS,
            )
        )

        self.add_parameter(
            Param(
                'P threshold',
                'p_threshold',
                default=self.get_default_threshold('P'),
                minimum=0.0,
                maximum=1.0,
            )
        )
        self.add_parameter(
            Param(
                'S threshold', 's_threshold',
                default=self.get_default_threshold('S'),
                minimum=0.0,
                maximum=1.0,
            )
        )

        self.add_parameter(
            Switch(
                'Show annotation traces', 'show_annotation_traces',
                default=False
            )
        )
        self.add_parameter(
            Switch(
                'Use pre-defined filters', 'use_predefined_filters',
                default=True
            )
        )
        self.add_parameter(
            Param(
                'Rescaling factor', 'scale_factor',
                default=1.0,
                minimum=0.1,
                maximum=10.0
            )
        )

        self.set_live_update(True)

    def get_default_threshold(self, phase: str) -> float:
        """Return the default detection threshold for *phase*.

        Returns 0.3 when SeisBench is not installed or the training model
        is ``'original'``; otherwise the ``'<phase>_threshold'`` entry of
        the selected model's ``default_args``.

        Raises:
            ValueError: if a model default is requested for a phase other
                than ``'P'`` or ``'S'``.
        """
        if sbm is None or self.training_model == 'original':
            return self._FALLBACK_THRESHOLD

        # Previously an unknown phase fell through and returned None.
        if phase not in ('P', 'S'):
            raise ValueError(
                "Unknown phase %r, expected 'P' or 'S'." % (phase,))

        model = self.get_model(self.network, self.training_model)
        return model.default_args['%s_threshold' % phase]

    def get_model(self, network: str, model: str) -> "WaveformModel":
        """Return the model for (*network*, *model*), loading it on first use.

        Loaded models are kept in ``MODEL_CACHE``. A new model is moved to
        CUDA when possible, switched to eval mode and passed through
        ``torch.compile``; a failure of the CUDA move or of the compile
        step is logged and the model is used as it is.

        Raises:
            ImportError: if SeisBench is not installed.
        """
        if sbm is None:
            raise ImportError(
                'SeisBench is not installed. Install to use this plugin.')

        key = (network, model)
        if key in MODEL_CACHE:
            return MODEL_CACHE[key]

        # Plain attribute lookup instead of eval() on an interpolated
        # string: same result, no code execution from parameter values.
        seisbench_model = getattr(sbm, network).from_pretrained(model)

        try:
            seisbench_model = seisbench_model.to('cuda')
        except (RuntimeError, AssertionError):
            logger.info('CUDA not available, using CPU')

        seisbench_model.eval()
        try:
            seisbench_model = torch.compile(
                seisbench_model, mode='max-autotune')
        except RuntimeError:
            logger.info('Torch compile failed')

        MODEL_CACHE[key] = seisbench_model
        return seisbench_model

    def set_default_thresholds(self) -> None:
        """Reset both threshold parameters for the current training model.

        ``'original'`` and the LFEDetect training models get 0.3; all
        other models get the defaults from get_default_threshold().
        """
        if (self.training_model == 'original'
                or self.training_model in self._LFE_MODELS):
            p_default = s_default = self._FALLBACK_THRESHOLD
        else:
            p_default = self.get_default_threshold('P')
            s_default = self.get_default_threshold('S')

        self.set_parameter('p_threshold', p_default)
        self.set_parameter('s_threshold', s_default)

    def content_changed(self) -> None:
        """Re-run the detection on viewer changes if live update is on."""
        if self._live_update:
            self.call()

    def panel_visibility_changed(self, visible: bool) -> None:
        """Connect the viewer signals while the panel is visible."""
        if visible:
            self._connect_signals()
        else:
            self._disconnect_signals()

    def _connect_signals(self) -> None:
        """Subscribe content_changed() to pile and filter changes."""
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.connect(self.content_changed)
        viewer.frequency_filter_changed.connect(self.content_changed)

    def _disconnect_signals(self) -> None:
        """Unsubscribe content_changed() from pile and filter changes."""
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.disconnect(self.content_changed)
        viewer.frequency_filter_changed.disconnect(self.content_changed)

    def call(self) -> None:
        """Run the picker on the selected traces and add phase markers.

        Traces are processed in batches of 300 s with 1 s padding (plus
        filter padding when the viewer's filters are applied). Traces
        whose ``meta`` flags them as ``'tabu'`` are skipped, which
        excludes the annotation traces this snuffling adds itself.
        """
        self.cleanup()
        self.get_viewer().clean_update()
        self.adjust_thresholds()

        model = self.get_model(self.network, self.training_model)

        tinc = 300.0
        tpad = 1.0

        tpad_filter = 0.0
        if self.use_predefined_filters:
            fmin = self.get_viewer().highpass
            # Two periods of the high-pass corner; a missing (None) or
            # zero corner needs no padding and must not be divided by.
            tpad_filter = 2.0 / fmin if fmin else 0.0

        for batch in self.chopper_selected_traces(
            tinc=tinc,
            tpad=tpad + tpad_filter,
            fallback=True,
            mode='visible',
            progress='Calculating SeisBench detections...',
            responsive=True,
            style='batch',
        ):
            traces: list[Trace] = [
                trace for trace in batch.traces
                if not (trace.meta and trace.meta.get('tabu', False))]

            if not traces:
                continue

            wmin, wmax = batch.tmin, batch.tmax

            if self.use_predefined_filters:
                traces = [self.apply_filter(tr, tpad_filter) for tr in traces]

            stream = Stream([compat.to_obspy_trace(tr) for tr in traces])

            # Start time per 'net.sta.loc', needed to map pick times back
            # to real time after the sampling rate has been rescaled.
            tr_starttimes: Dict[str, float] = {}
            if self.scale_factor != 1:
                for tr in stream:
                    tr.stats.sampling_rate /= self.scale_factor
                    s = tr.stats
                    tr_nsl = '.'.join((s.network, s.station, s.location))
                    tr_starttimes[tr_nsl] = s.starttime.timestamp

            output_classify = model.classify(
                stream,
                P_threshold=self.p_threshold,
                S_threshold=self.s_threshold,
            )

            if self.show_annotation_traces:
                self._add_annotation_traces(model, stream, wmin, wmax)

            self.pick_list = output_classify.picks

            self.add_markers(self._picks_to_markers(
                output_classify.picks, traces, tr_starttimes, wmin, wmax))

    def _add_annotation_traces(self, model, stream, wmin, wmax) -> None:
        """Add the model's annotation traces for *stream* to the viewer."""
        output_annotation = model.annotate(
            stream,
            P_threshold=self.p_threshold,
            S_threshold=self.s_threshold,
        )

        if self.scale_factor != 1:
            for tr in output_annotation:
                # Undo the sampling-rate rescaling applied before inference.
                tr.stats.sampling_rate *= self.scale_factor
                blinding_samples = max(get_blinding_samples(model))
                # 100 Hz is the native sampling rate of the model
                blinding_seconds = (blinding_samples / 100.0) * \
                    (1.0 - 1.0 / self.scale_factor)
                tr.stats.starttime -= blinding_seconds

        annotated_traces = []
        for tr in compat.to_pyrocko_traces(output_annotation):
            # Channels ending in 'N' are not shown.
            # NOTE(review): presumably the model's noise class - confirm.
            if tr.channel.endswith('N'):
                continue
            tr = tr.copy()
            tr.chop(wmin, wmax)
            # 'tabu' keeps call() from feeding these traces to the model.
            tr.meta = {'tabu': True, 'annotation': True}
            annotated_traces.append(tr)

        # Adding traces changes the pile; with the signals connected that
        # would re-trigger call(). Reconnect even if add_traces() fails.
        self._disconnect_signals()
        try:
            self.add_traces(annotated_traces)
        finally:
            self._connect_signals()

    @staticmethod
    def _channel_pattern(traces, trace_id: str) -> str:
        """Return the common channel prefix of *trace_id*'s traces + '*'."""
        network, station, location = trace_id.split('.')
        nsl = (network, station, location)
        return commonprefix(
            [tr.channel for tr in traces if tr.nslc_id[:-1] == nsl]) + '*'

    def _picks_to_markers(self, picks, traces, tr_starttimes, wmin, wmax):
        """Convert *picks* inside ``[wmin, wmax)`` into phase markers."""
        markers = []
        for pick in picks:
            tpeak = pick.peak_time.timestamp
            if self.scale_factor != 1:
                # Map the pick time from rescaled time back to real time.
                tr_starttime = tr_starttimes[pick.trace_id]
                tpeak = tr_starttime + \
                    (pick.peak_time.timestamp - tr_starttime) \
                    / self.scale_factor

            # Skip picks in the padding around the batch window.
            if not wmin <= tpeak < wmax:
                continue

            codes = tuple(pick.trace_id.split('.')) + \
                (self._channel_pattern(traces, pick.trace_id),)
            markers.append(PhaseMarker(
                [codes],
                tmin=tpeak,
                tmax=tpeak,
                # Marker kind 0 for 'P' picks, 1 for every other phase.
                kind=0 if pick.phase == 'P' else 1,
                phasename=pick.phase,
                # NOTE(review): the pick's peak value is stored in the
                # incidence_angle field - presumably only as a carrier.
                incidence_angle=pick.peak_value,
            ))
        return markers

    def adjust_thresholds(self) -> None:
        """Keep network and training model compatible; reset thresholds.

        Acts only when the network or training model differs from the
        last call. An incompatible combination is replaced by
        ``'cascadia'`` (for LFEDetect) or ``'original'`` (for the other
        networks), then both thresholds are reset to their defaults.
        """
        method = self.get_parameter_value('training_model')
        network = self.get_parameter_value('network')

        if method == self.old_method and network == self.old_network:
            return

        is_lfe_model = method in self._LFE_MODELS
        if network == 'LFEDetect' and not is_lfe_model:
            logger.info(
                'The selected model is not compatible with LFEDetect '
                'please select a model from the last 8 models in the '
                'list. Default is cascadia.')
            self.set_parameter('training_model', 'cascadia')
        elif network != 'LFEDetect' and is_lfe_model:
            logger.info(
                'The selected model is not compatible with the selected '
                'network. Default is original.')
            self.set_parameter('training_model', 'original')

        self.set_default_thresholds()

        # Remember the effective selection. Storing the pre-correction
        # value would make the next call look like another change and
        # reset the thresholds the user has just edited.
        self.old_method = self.get_parameter_value('training_model')
        self.old_network = network

    def apply_filter(self, tr: Trace, tcut: float) -> Trace:
        """Apply the viewer's filters to *tr* in place and trim its edges.

        A 4th-order low-pass and/or high-pass is applied for whichever
        corner the viewer has set, then *tcut* seconds are chopped off
        both ends. A trace too short to be chopped is left unchopped.

        Returns:
            The same trace object that was passed in.
        """
        viewer = self.get_viewer()
        if viewer.lowpass is not None:
            tr.lowpass(4, viewer.lowpass, nyquist_exception=False)
        if viewer.highpass is not None:
            tr.highpass(4, viewer.highpass, nyquist_exception=False)
        try:
            tr.chop(tr.tmin + tcut, tr.tmax - tcut)
        except NoData:
            # Nothing would remain after trimming; keep the trace as is.
            pass
        return tr

381 

382 

def __snufflings__():
    """Return the snuffling instances provided by this module."""
    detector = SeisBenchDetector()
    return [detector]