Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/gui/snuffler/snufflings/deep_picker.py: 29%
182 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-09-24 11:43 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2024-09-24 11:43 +0000
1# https://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6from typing import Any
7import logging
8from pyrocko.trace import Trace, NoData
9from pyrocko import obspy_compat as compat
11from ..snuffling import Param, Snuffling, Choice, Switch
12from ..marker import PhaseMarker
14try:
15 import seisbench.models as sbm
16 import torch
17 from obspy import Stream
18except ImportError:
19 sbm = None
20 Stream = None
# Seconds per hour (not referenced elsewhere in this module).
h = 3600.0
logger = logging.getLogger(__name__)

# Names of the pretrained weight sets offered in the GUI (passed to
# ``from_pretrained``). The last eight entries are the only ones accepted for
# the LFEDetect network; the remaining ones are used with all other networks.
detectionmethods = (
    'original', 'ethz', 'instance', 'scedc', 'stead', 'geofon',
    'neic', 'cascadia', 'cms', 'jcms', 'jcs', 'jms', 'mexico',
    'nankai', 'san_andreas')

# SeisBench network architectures offered in the GUI; each name is looked up
# as an attribute of ``seisbench.models``.
networks = (
    'PhaseNet', 'EQTransformer', 'GPD', 'LFEDetect'
)

# Loaded (and possibly compiled) models, keyed by the
# ``(network, training model)`` pair.
MODEL_CACHE: dict[tuple[str, str], Any] = {}
def get_blinding_samples(model: "sbm.WaveformModel") -> tuple[int, int]:
    '''
    Return the ``blinding`` setting of a SeisBench model.

    ``model.default_args["blinding"]`` is preferred. If that key is missing,
    the second element of ``model._annotate_args["blinding"]`` is returned
    instead.
    '''
    try:
        blinding = model.default_args["blinding"]
    except KeyError:
        # Fall back to the entry registered in the model's annotate-argument
        # table; its element 1 holds the value used here.
        annotate_entry = model._annotate_args["blinding"]
        blinding = annotate_entry[1]
    return blinding
class DeepDetector(Snuffling):
    '''
    Snuffling that runs pretrained SeisBench pickers on the visible traces
    and adds P/S phase markers (and optionally the model's annotation traces)
    to the viewer.
    '''

    # Training models usable with the LFEDetect network (the last eight
    # entries of ``detectionmethods``); other networks use the remaining ones.
    _LFE_METHODS = detectionmethods[-8:]

    def __init__(self, *args, **kwargs):
        Snuffling.__init__(self, *args, **kwargs)
        # Initial training model; get_default_threshold() reads this
        # attribute.
        self.training_model = 'original'
        # Network/training model seen by the previous adjust_thresholds()
        # call, used to detect a change of selection.
        self.old_method: str | None = None
        self.old_network: str | None = None
        # Whether content_changed() is currently connected to the viewer's
        # signals (see _connect_signals / _disconnect_signals).
        self._signals_connected = False
        # Picks of the most recently processed batch; None without SeisBench.
        try:
            from seisbench.util.annotations import PickList
        except ImportError:
            self.pick_list = None
        else:
            self.pick_list = PickList()

    def help(self) -> str:
        return '''
<html>
<head>
<style type="text/css">
    body { margin-left:10px };
</style>
</head>
<body>
<h1 align="center">SeisBench: ML Picker</h1>
<p>
    Automatic detection of P- and S-Phases in the given traces, using various
    pre-trained ML models from SeisBench.
</p>
<p>
<b>Parameters:</b><br />
    <b>· Network</b>
    - Choose the network architecture used for detection. <br />
    <b>· Training model</b>
    - Choose the pretrained weights for the network. LFEDetect only supports
    the last eight entries of the list. <br />
    <b>· P threshold</b>
    - Define a trigger threshold for the P-Phase detection <br />
    <b>· S threshold</b>
    - Define a trigger threshold for the S-Phase detection <br />
    <b>· Show annotation traces</b>
    - Add the model's annotation traces (except the noise channel) to the
    viewer. <br />
    <b>· Use predefined filters</b>
    - Apply the viewer's lowpass/highpass filters before detection. <br />
    <b>· Rescaling factor</b>
    - Stretch the time axis of the data by this factor before running the
    model; pick times are mapped back to the original time axis. <br />
</p>
<p>
    <span style="color:red">P-Phases</span> are marked with red markers, <span
    style="color:green">S-Phases</span> with green markers.
</p>
<p>
    More information about SeisBench can be found <a
    href="https://seisbench.readthedocs.io/en/stable/index.html">on the
    seisbench website</a>.
</p>
</body>
</html>
'''

    def setup(self) -> None:
        '''Register the snuffling's name and GUI parameters.'''
        self.set_name('SeisBench: ML Picker')
        self.add_parameter(
            Choice(
                'Network',
                'network',
                default='PhaseNet',
                choices=networks,
            )
        )

        self.add_parameter(
            Choice(
                'Training model',
                'training_model',
                default='original',
                choices=detectionmethods,
            )
        )

        self.add_parameter(
            Param(
                'P threshold',
                'p_threshold',
                default=self.get_default_threshold('P'),
                minimum=0.0,
                maximum=1.0,
            )
        )
        self.add_parameter(
            Param(
                'S threshold', 's_threshold',
                default=self.get_default_threshold('S'),
                minimum=0.0,
                maximum=1.0,
            )
        )

        self.add_parameter(
            Switch(
                'Show annotation traces', 'show_annotation_traces',
                default=False
            )
        )
        self.add_parameter(
            Switch(
                'Use predefined filters', 'use_predefined_filters',
                default=True
            )
        )
        self.add_parameter(
            Param(
                'Rescaling factor', 'scale_factor',
                default=1.0, minimum=0.1, maximum=10.0
            )
        )

        self.set_live_update(True)

    def get_default_threshold(self, phase: str) -> float:
        '''
        Return the default trigger threshold for ``phase`` ('P' or 'S').

        0.3 is returned when SeisBench is unavailable or the 'original'
        training model is selected. Otherwise the value is read from the
        pretrained model's ``default_args`` (which loads the model).

        :raises ValueError: if ``phase`` is neither 'P' nor 'S'.
        '''
        if phase not in ('P', 'S'):
            raise ValueError(f'Unknown phase {phase!r}, expected "P" or "S".')

        if sbm is None or self.training_model == 'original':
            return 0.3

        model = self.get_model(self.network, self.training_model)
        return model.default_args[f'{phase}_threshold']

    def get_model(self, network: str, model: str) -> Any:
        '''
        Return SeisBench network ``network`` with pretrained weights
        ``model``.

        Each ``(network, model)`` pair is loaded only once and kept in
        ``MODEL_CACHE``. On first load the model is moved to CUDA when that
        works, put into evaluation mode and, if possible, wrapped with
        ``torch.compile``.

        :raises ImportError: if the optional SeisBench/torch/obspy imports
            of this module failed.
        '''
        if sbm is None:
            raise ImportError(
                'SeisBench is not installed. Install to use this plugin.')

        # The cache must be keyed by the pair: the same training model name
        # is offered for several networks.
        key = (network, model)
        if key in MODEL_CACHE:
            return MODEL_CACHE[key]

        # Look the network class up by name rather than eval()-ing code built
        # from parameter values.
        seisbench_model = getattr(sbm, network).from_pretrained(model)
        try:
            seisbench_model = seisbench_model.to('cuda')
        except (RuntimeError, AssertionError):
            logger.info('CUDA not available, using CPU')

        seisbench_model.eval()
        try:
            seisbench_model = torch.compile(
                seisbench_model, mode='max-autotune')
        except (RuntimeError, AttributeError):
            # AttributeError: torch.compile does not exist before torch 2.0.
            logger.info('Torch compile failed')

        MODEL_CACHE[key] = seisbench_model
        return seisbench_model

    def set_default_thresholds(self) -> None:
        '''
        Reset the P and S threshold parameters to the defaults of the
        currently selected training model (0.3 for 'original' and for the
        LFEDetect training models).
        '''
        if self.training_model == 'original' \
                or self.training_model in self._LFE_METHODS:
            p_threshold = s_threshold = 0.3
        else:
            p_threshold = self.get_default_threshold('P')
            s_threshold = self.get_default_threshold('S')

        self.set_parameter('p_threshold', p_threshold)
        self.set_parameter('s_threshold', s_threshold)

    def content_changed(self) -> None:
        '''Re-run the picker when live update is enabled.'''
        if self._live_update:
            self.call()

    def panel_visibility_changed(self, visible: bool) -> None:
        '''Only follow viewer changes while the panel is visible.'''
        if visible:
            self._connect_signals()
        else:
            self._disconnect_signals()

    def _connect_signals(self) -> None:
        # Connect content_changed() to the viewer; no-op if already connected
        # (avoids duplicate connections and thus duplicate runs).
        if self._signals_connected:
            return
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.connect(self.content_changed)
        viewer.frequency_filter_changed.connect(self.content_changed)
        self._signals_connected = True

    def _disconnect_signals(self) -> None:
        # Disconnect content_changed() from the viewer; no-op if not
        # connected (disconnecting an unconnected slot would raise).
        if not self._signals_connected:
            return
        viewer = self.get_viewer()
        viewer.pile_has_changed_signal.disconnect(self.content_changed)
        viewer.frequency_filter_changed.disconnect(self.content_changed)
        self._signals_connected = False

    def call(self) -> None:
        '''
        Run the selected model on the visible traces in 300 s batches and
        add a phase marker for every pick inside each batch's window.
        '''
        self.cleanup()
        self.adjust_thresholds()
        model = self.get_model(self.network, self.training_model)

        tinc = 300.0
        tpad = 1.0

        # Extra padding (two periods of the highpass corner) that
        # apply_filter() cuts off again after filtering. ``not fmin`` also
        # guards against a zero corner frequency.
        tpad_filter = 0.0
        if self.use_predefined_filters:
            fmin = self.get_viewer().highpass
            tpad_filter = 0.0 if not fmin else 2.0 / fmin

        for batch in self.chopper_selected_traces(
                tinc=tinc,
                tpad=tpad + tpad_filter,
                fallback=True,
                mode='visible',
                progress='Calculating SeisBench detections...',
                responsive=True,
                style='batch',
                ):
            traces = batch.traces
            if not traces:
                continue

            wmin, wmax = batch.tmin, batch.tmax

            for tr in traces:
                tr.meta = {'tabu': True}

            if self.use_predefined_filters:
                traces = [self.apply_filter(tr, tpad_filter) for tr in traces]

            stream = Stream([compat.to_obspy_trace(tr) for tr in traces])
            tr_starttimes = self._rescale_stream(stream)

            output_classify = model.classify(
                stream,
                P_threshold=self.p_threshold,
                S_threshold=self.s_threshold,
            )

            if self.show_annotation_traces:
                self._add_annotation_traces(model, stream, wmin, wmax)

            self.pick_list = output_classify.picks
            self.add_markers(self._picks_to_markers(
                output_classify.picks, tr_starttimes, wmin, wmax))

    def _rescale_stream(self, stream: "Stream") -> dict[str, float]:
        # Stretch the traces of ``stream`` in place by ``scale_factor`` (by
        # dividing their sampling rate) and return each trace's start time
        # keyed by 'NET.STA.LOC'. Returns {} when scale_factor is 1.
        tr_starttimes: dict[str, float] = {}
        if self.scale_factor != 1:
            for tr in stream:
                tr.stats.sampling_rate /= self.scale_factor
                s = tr.stats
                tr_nsl = '.'.join((s.network, s.station, s.location))
                tr_starttimes[tr_nsl] = s.starttime.timestamp
        return tr_starttimes

    def _add_annotation_traces(
            self, model: Any, stream: "Stream",
            wmin: float, wmax: float) -> None:
        # Compute the model annotation for ``stream``, map it back to the
        # original time axis and add it (cut to [wmin, wmax]) to the viewer.
        output_annotation = model.annotate(
            stream,
            P_threshold=self.p_threshold,
            S_threshold=self.s_threshold,
        )

        if self.scale_factor != 1 and len(output_annotation) > 0:
            # Undo the time stretching. The start time is corrected as well,
            # presumably because the annotation begins ``blinding`` samples
            # after the (stretched) input start.
            # NOTE(review): assumes the model samples at 100 Hz - confirm for
            # every supported network.
            blinding_samples = max(get_blinding_samples(model))
            blinding_seconds = (blinding_samples / 100.0) * \
                (1.0 - 1 / self.scale_factor)
            for tr in output_annotation:
                tr.stats.sampling_rate *= self.scale_factor
                tr.stats.starttime -= blinding_seconds

        ano_traces = []
        for tr in compat.to_pyrocko_traces(output_annotation):
            # Skip channels ending in 'N' (the noise annotation).
            if tr.channel[-1] == 'N':
                continue
            try:
                tr.chop(wmin, wmax)
            except NoData:
                # No annotation samples inside this batch's window.
                continue
            tr.meta = {'tabu': True, 'annotation': True}
            ano_traces.append(tr)

        # Adding traces changes the pile; keep that from re-triggering
        # content_changed() while they are added, then restore the previous
        # connection state even if add_traces() fails.
        was_connected = self._signals_connected
        self._disconnect_signals()
        try:
            self.add_traces(ano_traces)
        finally:
            if was_connected:
                self._connect_signals()

    def _picks_to_markers(
            self, picks: Any, tr_starttimes: dict[str, float],
            wmin: float, wmax: float) -> list[PhaseMarker]:
        # Convert SeisBench picks into PhaseMarkers, keeping only picks whose
        # peak time lies in [wmin, wmax).
        markers = []
        for pick in picks:
            tpeak = pick.peak_time.timestamp
            if self.scale_factor != 1:
                # Map the peak time from the stretched time axis back to the
                # original one, relative to the trace start.
                tr_starttime = tr_starttimes[pick.trace_id]
                tpeak = tr_starttime + \
                    (tpeak - tr_starttime) / self.scale_factor

            # Picks in the padding are presumably covered by the neighbouring
            # batch.
            if not (wmin <= tpeak < wmax):
                continue

            codes = tuple(pick.trace_id.split('.')) + ('*',)
            markers.append(PhaseMarker(
                [codes],
                tmin=tpeak,
                tmax=tpeak,
                kind=0 if pick.phase == 'P' else 1,
                phasename=pick.phase,
                # The pick's peak value is stored in the incidence_angle
                # field of the marker.
                incidence_angle=pick.peak_value,
            ))
        return markers

    def adjust_thresholds(self) -> None:
        '''
        Force a training model compatible with the selected network and,
        when network or training model changed since the last call, reset
        the thresholds to their defaults.
        '''
        method = self.get_parameter_value('training_model')
        network = self.get_parameter_value('network')
        if method == self.old_method and network == self.old_network:
            return

        if network == 'LFEDetect' and method not in self._LFE_METHODS:
            logger.info(
                'The selected model is not compatible with LFEDetect '
                'please select a model from the last 8 models in the '
                'list. Default is cascadia.')
            method = 'cascadia'
            self.set_parameter('training_model', method)
        elif network != 'LFEDetect' and method in self._LFE_METHODS:
            logger.info(
                'The selected model is not compatible with the selected '
                'network. Default is original.')
            method = 'original'
            self.set_parameter('training_model', method)

        self.set_default_thresholds()
        # Remember the model actually in use (after any correction above);
        # storing the rejected name would make the next call see a change
        # and overwrite thresholds the user adjusted in the meantime.
        self.old_method = method
        self.old_network = network

    def apply_filter(self, tr: Trace, tcut: float) -> Trace:
        '''
        Apply the viewer's current lowpass/highpass filters to ``tr`` and cut
        ``tcut`` seconds from both ends.

        The trace is modified in place and returned. If the cut would leave
        no data, the filtered but uncut trace is returned.
        '''
        viewer = self.get_viewer()
        if viewer.lowpass is not None:
            tr.lowpass(4, viewer.lowpass, nyquist_exception=False)
        if viewer.highpass is not None:
            tr.highpass(4, viewer.highpass, nyquist_exception=False)
        try:
            tr.chop(tr.tmin + tcut, tr.tmax - tcut)
        except NoData:
            pass
        return tr
def __snufflings__():
    '''Return the list of snufflings provided by this module.'''
    snufflings = [DeepDetector()]
    return snufflings