Coverage for /usr/local/lib/python3.13/dist-packages/pyrocko/gui/snuffler/snufflings/seisbench_picker.py: 30%
197 statements
« prev ^ index » next coverage.py v7.6.0, created at 2025-12-04 10:41 +0000
« prev ^ index » next coverage.py v7.6.0, created at 2025-12-04 10:41 +0000
1# https://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6from typing import Dict, Tuple
7import logging
8from os.path import commonprefix
9from pyrocko.trace import Trace, NoData
10from pyrocko import obspy_compat as compat
12from ..snuffling import Param, Snuffling, Choice, Switch
13from ..marker import PhaseMarker
15try:
16 import seisbench.models as sbm
17 from seisbench.models import WaveformModel
18 import torch
19 from obspy import Stream
20except ImportError:
21 sbm = None
22 Stream = None
# Number of seconds in one hour.
HOUR = 3600.0

logger = logging.getLogger(__name__)

# SeisBench model architectures offered by the 'Network' chooser.
ML_NETWORKS = (
    'PhaseNet',
    'EQTransformer',
    'GPD',
    'LFEDetect',
)

# Pre-trained weight sets offered by the 'Training model' chooser.
# NOTE: order matters -- the picker treats the last eight entries
# (``TRAINED_MODELS[-8:]``) as the weight sets reserved for LFEDetect.
TRAINED_MODELS = (
    'original',
    'ethz',
    'instance',
    'scedc',
    'stead',
    'geofon',
    'neic',
    'cascadia',
    'cms',
    'jcms',
    'jcs',
    'jms',
    'mexico',
    'nankai',
    'san_andreas',
)

# Process-wide cache of loaded models, keyed by (network, training model),
# so that each pre-trained model is only loaded once.
MODEL_CACHE: Dict[Tuple[str, str], "WaveformModel"] = {}
def get_blinding_samples(model: "WaveformModel") -> Tuple[int, int]:
    """Look up the blinding setting of a SeisBench model.

    The value stored under ``"blinding"`` in ``model.default_args`` is
    preferred. If that key is missing, the second element of the
    ``"blinding"`` entry in ``model._annotate_args`` is used instead.

    :param model: model object exposing ``default_args`` and
        ``_annotate_args`` mappings.
    :returns: the blinding entry (annotated as a pair of sample counts).
    :raises KeyError: if neither mapping has a ``"blinding"`` entry.
    """
    defaults = model.default_args
    try:
        blinding = defaults["blinding"]
    except KeyError:
        # No model-specific default: fall back to the annotate-args entry,
        # whose second element holds the value.
        blinding = model._annotate_args["blinding"][1]
    return blinding
class SeisBenchDetector(Snuffling):
    """Snuffling that picks P and S phases with pre-trained SeisBench models.

    The selected traces are processed in time windows, passed to the chosen
    SeisBench model and every pick above the configured threshold is added
    to the viewer as a :class:`PhaseMarker`. Optionally the model's
    annotation (probability) traces are added to the viewer as well.
    """

    # Values of the GUI parameters registered in :meth:`setup`.
    network: str
    training_model: str

    p_threshold: float
    s_threshold: float
    scale_factor: float

    show_annotation_traces: bool
    use_predefined_filters: bool

    # Weight sets that are only offered in combination with LFEDetect.
    _LFE_TRAINED_MODELS = TRAINED_MODELS[-8:]

    # Threshold used when no model specific default is looked up.
    _FALLBACK_THRESHOLD = 0.3

    def __init__(self, *args, **kwargs):
        """Initialise the snuffling state.

        ``pick_list`` is an empty SeisBench ``PickList`` when SeisBench can
        be imported and ``None`` otherwise.
        """
        Snuffling.__init__(self, *args, **kwargs)
        # Must be set before setup() queries the default thresholds.
        self.training_model = 'original'

        # Last (training model, network) combination seen by
        # adjust_thresholds(); used to detect changes of the selection.
        self.old_method: str = ''
        self.old_network: str = ''

        try:
            from seisbench.util.annotations import PickList
        except ImportError:
            self.pick_list = None
        else:
            self.pick_list = PickList()

    def help(self) -> str:
        """Return the HTML help text shown for this snuffling."""
        return '''
<html>
<head>
<style type="text/css">
    body { margin-left:10px; }
</style>
</head>
<body>
<h1 align="center">SeisBench: ML Picker</h1>
<p>
    Automatic detection of P- and S-Phases in the given traces, using various
    pre-trained ML models from SeisBench.<br/>
</p>
<p>
<b>Parameters:</b><br />
    <b>· P threshold</b>
    - Define a trigger threshold for the P-Phase detection <br />
    <b>· S threshold</b>
    - Define a trigger threshold for the S-Phase detection <br />
    <b>· Detection method</b>
    - Choose the pretrained model, used for detection. <br />
    <b>· Rescaling</b>
    - Factor to stretch and compress the waveforms before inference <br />
</p>
<p>
    <span style="color:red">P-Phases</span> are marked with red markers, <span
    style="color:green">S-Phases</span> with green markers.
</p>
<p>
    More information about SeisBench can be found <a
    href="https://seisbench.readthedocs.io/en/stable/index.html">on the
    SeisBench website</a>.
</p>
</body>
</html>'''

    def setup(self) -> None:
        """Register the snuffling name and its GUI parameters."""
        self.set_name('SeisBench: ML Picker')
        self.add_parameter(
            Choice(
                'Network',
                'network',
                default='PhaseNet',
                choices=ML_NETWORKS,
            )
        )

        self.add_parameter(
            Choice(
                'Training model',
                'training_model',
                default='original',
                choices=TRAINED_MODELS,
            )
        )

        self.add_parameter(
            Param(
                'P threshold',
                'p_threshold',
                default=self.get_default_threshold('P'),
                minimum=0.0,
                maximum=1.0,
            )
        )
        self.add_parameter(
            Param(
                'S threshold',
                's_threshold',
                default=self.get_default_threshold('S'),
                minimum=0.0,
                maximum=1.0,
            )
        )

        self.add_parameter(
            Switch(
                'Show annotation traces',
                'show_annotation_traces',
                default=False,
            )
        )
        self.add_parameter(
            Switch(
                'Use pre-defined filters',
                'use_predefined_filters',
                default=True,
            )
        )
        self.add_parameter(
            Param(
                'Rescaling factor',
                'scale_factor',
                default=1.0,
                minimum=0.1,
                maximum=10.0,
            )
        )

        self.set_live_update(True)

    def get_default_threshold(self, phase: str) -> float:
        """Return the default pick threshold for ``phase``.

        :param phase: ``'P'`` or ``'S'``.
        :returns: ``0.3`` when SeisBench is unavailable or the ``'original'``
            training model is selected, otherwise the ``P_threshold`` /
            ``S_threshold`` entry of the model's ``default_args``.
        :raises ValueError: if a model default is requested for a phase
            other than ``'P'`` or ``'S'``.
        """
        if sbm is None or self.training_model == 'original':
            return self._FALLBACK_THRESHOLD

        if phase not in ('P', 'S'):
            # Previously this fell through and returned None, which is not
            # a usable threshold.
            raise ValueError(
                'Unsupported phase %r, expected "P" or "S".' % (phase,))

        model = self.get_model(self.network, self.training_model)
        return model.default_args['%s_threshold' % phase]

    def get_model(self, network: str, model: str) -> "WaveformModel":
        """Load a pre-trained SeisBench model, using the module level cache.

        The model is moved to CUDA when possible, switched to evaluation
        mode and wrapped with ``torch.compile``. Failures of the CUDA
        transfer or of the compilation are logged and otherwise ignored.

        :param network: name of the model class in ``seisbench.models``.
        :param model: name of the pre-trained weight set.
        :returns: the ready-to-use model.
        :raises ImportError: if SeisBench is not installed.
        """
        if sbm is None:
            raise ImportError(
                'SeisBench is not installed. Install to use this plugin.')

        key = (network, model)
        if key in MODEL_CACHE:
            return MODEL_CACHE[key]

        # Resolve the model class by attribute lookup instead of building
        # and eval()-ing a source string from parameter values.
        seisbench_model = getattr(sbm, network).from_pretrained(model)

        try:
            seisbench_model = seisbench_model.to('cuda')
        except (RuntimeError, AssertionError):
            logger.info('CUDA not available, using CPU')

        seisbench_model.eval()
        try:
            seisbench_model = torch.compile(
                seisbench_model, mode='max-autotune')
        except RuntimeError:
            logger.info('Torch compile failed')

        MODEL_CACHE[key] = seisbench_model
        return seisbench_model

    def set_default_thresholds(self) -> None:
        """Reset both threshold parameters for the selected training model.

        The ``'original'`` weights and the LFEDetect-only weight sets use
        the fixed fallback of ``0.3``; every other weight set uses the
        defaults stored with the model.
        """
        uses_fallback = (
            self.training_model == 'original'
            or self.training_model in self._LFE_TRAINED_MODELS)

        if uses_fallback:
            self.set_parameter('p_threshold', self._FALLBACK_THRESHOLD)
            self.set_parameter('s_threshold', self._FALLBACK_THRESHOLD)
        else:
            self.set_parameter('p_threshold', self.get_default_threshold('P'))
            self.set_parameter('s_threshold', self.get_default_threshold('S'))

    def content_changed(self) -> None:
        """Re-run the detection when live update is enabled."""
        if self._live_update:
            self.call()

    def panel_visibility_changed(self, visible: bool) -> None:
        """Listen to viewer changes only while the panel is visible."""
        if visible:
            self._connect_signals()
        else:
            self._disconnect_signals()

    def _connect_signals(self) -> None:
        """Connect :meth:`content_changed` to the viewer's change signals."""
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.connect(self.content_changed)
        viewer.frequency_filter_changed.connect(self.content_changed)

    def _disconnect_signals(self) -> None:
        """Disconnect :meth:`content_changed` from the viewer's signals."""
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.disconnect(self.content_changed)
        viewer.frequency_filter_changed.disconnect(self.content_changed)

    def call(self) -> None:
        """Run the detection on the selected traces and add phase markers.

        Previous results are cleaned up first. The traces are processed in
        batches of 300 s (padded by 1 s plus a filter dependent padding),
        optionally filtered with the viewer's filter settings, optionally
        rescaled in time, and classified by the selected model.
        """
        self.cleanup()
        self.get_viewer().clean_update()
        self.adjust_thresholds()

        model = self.get_model(self.network, self.training_model)

        tinc = 300.0
        tpad = 1.0

        tpad_filter = 0.0
        if self.use_predefined_filters:
            fmin = self.get_viewer().highpass
            # Pad by two periods of the high-pass corner; a missing (or
            # zero) corner frequency needs no padding.
            tpad_filter = 2.0 / fmin if fmin else 0.0

        for batch in self.chopper_selected_traces(
            tinc=tinc,
            tpad=tpad + tpad_filter,
            fallback=True,
            mode='visible',
            progress='Calculating SeisBench detections...',
            responsive=True,
            style='batch',
        ):
            # Skip traces flagged as 'tabu' (e.g. annotation traces added
            # by a previous run of this snuffling).
            traces: list[Trace] = [
                tr for tr in batch.traces
                if not (tr.meta and tr.meta.get('tabu', False))]

            if not traces:
                continue

            wmin, wmax = batch.tmin, batch.tmax
            # Use one consistent scale factor for the whole batch.
            scale_factor = self.scale_factor

            if self.use_predefined_filters:
                traces = [self.apply_filter(tr, tpad_filter) for tr in traces]

            stream = Stream([compat.to_obspy_trace(tr) for tr in traces])

            # Start time per 'net.sta.loc', needed to undo the rescaling of
            # the pick times afterwards.
            tr_starttimes: Dict[str, float] = {}
            if scale_factor != 1:
                for tr in stream:
                    tr.stats.sampling_rate /= scale_factor
                    s = tr.stats
                    tr_nsl = '.'.join((s.network, s.station, s.location))
                    tr_starttimes[tr_nsl] = s.starttime.timestamp

            output_classify = model.classify(
                stream,
                P_threshold=self.p_threshold,
                S_threshold=self.s_threshold,
            )

            if self.show_annotation_traces:
                self._add_annotation_traces(
                    model, stream, wmin, wmax, scale_factor)

            # NOTE: only the picks of the most recently processed batch are
            # kept here.
            self.pick_list = output_classify.picks

            self.add_markers(self._picks_to_markers(
                output_classify.picks, traces, tr_starttimes,
                wmin, wmax, scale_factor))

    def _add_annotation_traces(
            self,
            model: "WaveformModel",
            stream,
            wmin: float,
            wmax: float,
            scale_factor: float) -> None:
        """Compute the model's annotation traces and add them to the viewer.

        :param model: model used for the annotation.
        :param stream: (possibly rescaled) ObsPy stream of the batch.
        :param wmin: start of the batch window.
        :param wmax: end of the batch window.
        :param scale_factor: rescaling factor applied to ``stream``.
        """
        output_annotation = model.annotate(
            stream,
            P_threshold=self.p_threshold,
            S_threshold=self.s_threshold,
        )

        if scale_factor != 1:
            blinding_samples = max(get_blinding_samples(model))
            # 100 Hz is the native sampling rate of the model
            blinding_seconds = (blinding_samples / 100.0) * \
                (1.0 - 1.0 / scale_factor)
            for tr in output_annotation:
                # Undo the time rescaling applied to the input stream.
                tr.stats.sampling_rate *= scale_factor
                tr.stats.starttime -= blinding_seconds

        annotated_traces = []
        for tr in compat.to_pyrocko_traces(output_annotation):
            # Channels ending in 'N' are not shown (presumably the noise
            # class of the model -- TODO confirm).
            if tr.channel.endswith('N'):
                continue
            tr = tr.copy()
            try:
                tr.chop(wmin, wmax)
            except NoData:
                # Nothing of this trace lies inside the batch window.
                continue
            # 'tabu' keeps these traces from being fed back into the model
            # on the next run.
            tr.meta = {'tabu': True, 'annotation': True}
            annotated_traces.append(tr)

        # content_changed() is connected to the viewer's pile change signal
        # and would re-enter call() while traces are being added, so stop
        # listening for the duration of add_traces(). The finally clause
        # guarantees the signals are restored even if add_traces() fails.
        self._disconnect_signals()
        try:
            self.add_traces(annotated_traces)
        finally:
            self._connect_signals()

    def _picks_to_markers(
            self,
            picks,
            traces: "list[Trace]",
            tr_starttimes: Dict[str, float],
            wmin: float,
            wmax: float,
            scale_factor: float) -> list:
        """Convert SeisBench picks inside ``[wmin, wmax)`` to phase markers.

        :param picks: iterable of picks as returned by ``model.classify``.
        :param traces: traces of the batch, used to derive channel codes.
        :param tr_starttimes: start time per ``'net.sta.loc'`` of the
            rescaled stream (only used if ``scale_factor != 1``).
        :param wmin: start of the batch window (inclusive).
        :param wmax: end of the batch window (exclusive).
        :param scale_factor: rescaling factor applied before inference.
        :returns: list of :class:`PhaseMarker` objects.
        """
        markers = []
        for pick in picks:
            tpeak = pick.peak_time.timestamp
            if scale_factor != 1:
                # Map the pick time from the rescaled time axis back to the
                # original one, relative to the start of its trace.
                tr_starttime = tr_starttimes[pick.trace_id]
                tpeak = tr_starttime + (tpeak - tr_starttime) / scale_factor

            # Picks in the padding belong to the neighbouring batch.
            if not wmin <= tpeak < wmax:
                continue

            nsl = tuple(pick.trace_id.split('.'))
            # Channel pattern shared by all components of the station,
            # e.g. 'HH*'.
            channel = commonprefix(
                [tr.channel for tr in traces if tr.nslc_id[:-1] == nsl]) + '*'

            markers.append(PhaseMarker(
                [nsl + (channel,)],
                tmin=tpeak,
                tmax=tpeak,
                kind=0 if pick.phase == 'P' else 1,
                phasename=pick.phase,
                # The pick's peak value is stored in the incidence_angle
                # attribute of the marker.
                incidence_angle=pick.peak_value,
            ))

        return markers

    def adjust_thresholds(self) -> None:
        """Validate the network / training model pair and reset thresholds.

        Does nothing unless the network or training model changed since the
        last invocation. An incompatible combination is corrected
        (``'cascadia'`` for LFEDetect, ``'original'`` for all other
        networks) and the thresholds are reset to the defaults of the
        resulting selection.
        """
        method = self.get_parameter_value('training_model')
        network = self.get_parameter_value('network')

        if method == self.old_method and network == self.old_network:
            return

        is_lfe_model = method in self._LFE_TRAINED_MODELS
        if network == 'LFEDetect' and not is_lfe_model:
            logger.info(
                'The selected model is not compatible with LFEDetect '
                'please select a model from the last 8 models in the '
                'list. Default is cascadia.')
            method = 'cascadia'
            self.set_parameter('training_model', method)
        elif network != 'LFEDetect' and is_lfe_model:
            logger.info(
                'The selected model is not compatible with the selected '
                'network. Default is original.')
            method = 'original'
            self.set_parameter('training_model', method)

        self.set_default_thresholds()

        # Remember the *corrected* selection. Storing the rejected value
        # would make the next call detect a spurious change and reset the
        # thresholds again, discarding the user's adjustments.
        self.old_method = method
        self.old_network = network

    def apply_filter(self, tr: Trace, tcut: float) -> Trace:
        """Apply the viewer's low-/high-pass filters to ``tr`` and trim it.

        The trace is modified by the filter and chop calls and returned.

        :param tr: trace to filter.
        :param tcut: number of seconds removed from both ends afterwards.
        :returns: the filtered trace; it is left untrimmed if chopping
            would leave no data.
        """
        viewer = self.get_viewer()
        if viewer.lowpass is not None:
            tr.lowpass(4, viewer.lowpass, nyquist_exception=False)
        if viewer.highpass is not None:
            tr.highpass(4, viewer.highpass, nyquist_exception=False)
        try:
            tr.chop(tr.tmin + tcut, tr.tmax - tcut)
        except NoData:
            # Trace too short to trim: keep it as is.
            pass
        return tr
def __snufflings__():
    """Return the snuffling instances provided by this module.

    :returns: list holding a single :class:`SeisBenchDetector`.
    """
    detector = SeisBenchDetector()
    return [detector]