Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/multitrace.py: 34%
188 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-02-27 10:58 +0000
1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6'''
7Multi-component waveform data model.
8'''
10import logging
11from functools import partialmethod
13import numpy as num
14import numpy.ma as ma
15from scipy import signal
17from . import trace, util
18from .trace import Trace, AboveNyquist, _get_cached_filter_coeffs
19from .guts import Object, Float, Timestamp, List
20from .guts_array import Array
21from .squirrel import \
22 CodesNSLCE, SensorGrouping
24from .squirrel.operators.base import ReplaceComponentTranslation
# Module-level logger; warnings from filter design and Nyquist checks go here.
logger = logging.getLogger(name='pyrocko.multitrace')
class MultiTrace(Object):
    '''
    Container for multi-component waveforms with common time span and sampling.

    Instances of this class can be used to efficiently represent
    multi-component waveforms of a single sensor or of a sensor array. The data
    samples are stored in a single 2D array where the first index runs over
    components and the second index over time. Metadata contains sampling rate,
    start-time and :py:class:`~pyrocko.squirrel.model.CodesNSLCE` identifiers
    for the contained traces.

    The :py:gattr:`data` is held as a NumPy :py:class:`numpy.ma.MaskedArray`
    where missing or invalid data is masked.

    :param traces:
        If given, construct multi-trace from given single-component waveforms.
        ``data`` is then ignored and, for a non-empty list, so are ``codes``,
        ``tmin`` and ``deltat``.
    :type traces:
        :py:class:`list` of :py:class:`~pyrocko.trace.Trace`

    :param assemble:
        How to combine ``traces``: ``'concatenate'`` stacks their samples with
        :py:func:`~pyrocko.trace.get_traces_data_as_array` and takes ``tmin``
        and ``deltat`` from the first trace; ``'merge'`` uses
        :py:func:`~pyrocko.trace.merge_traces_data_as_array`.
    :type assemble:
        str

    :param data:
        Sample array indexed as ``(icomponent, isample)``. Defaults to an
        empty ``(0, 0)`` array.
    :type data:
        :py:class:`numpy.ma.MaskedArray`

    :param codes:
        One code per component. Defaults to empty codes.
    :type codes:
        :py:class:`list` of :py:class:`~pyrocko.squirrel.model.CodesNSLCE`

    :param tmin:
        Start time. Defaults to the default of the ``tmin`` property.
    :type tmin:
        float

    :param deltat:
        Sampling interval [s]. Defaults to the default of the ``deltat``
        property.
    :type deltat:
        float

    :raises ValueError:
        If ``assemble`` is invalid or the number of codes does not match the
        number of components.
    '''

    codes = List.T(
        CodesNSLCE.T(),
        help='List of codes identifying the components.')
    data = Array.T(
        shape=(None, None),
        help='Array containing the data samples indexed as '
             '``(icomponent, isample)``.')
    tmin = Timestamp.T(
        default=Timestamp.D('1970-01-01 00:00:00'),
        help='Start time.')
    deltat = Float.T(
        default=1.0,
        help='Sampling interval [s]')

    def __init__(
            self,
            traces=None,
            assemble='concatenate',
            data=None,
            codes=None,
            tmin=None,
            deltat=None):

        if traces is not None:
            if len(traces) == 0:
                data = ma.zeros((0, 0))
            elif assemble == 'merge':
                data, codes, tmin, deltat \
                    = trace.merge_traces_data_as_array(traces)
            elif assemble == 'concatenate':
                data = ma.array(trace.get_traces_data_as_array(traces))
                codes = [tr.codes for tr in traces]
                tmin = traces[0].tmin
                deltat = traces[0].deltat
            else:
                raise ValueError(
                    'MultiTrace construction: invalid choice for assemble: '
                    '%r (choices: "merge", "concatenate").' % (assemble,))

        if data is None:
            # Same empty shape as produced for an empty list of traces.
            data = ma.zeros((0, 0))

        # Unpacking the shape also rejects arrays which are not 2D.
        ntraces, _ = data.shape

        if codes is None:
            codes = [CodesNSLCE()] * ntraces

        if len(codes) != ntraces:
            raise ValueError(
                'MultiTrace construction: mismatch between number of traces '
                'and number of codes given.')

        if deltat is None:
            deltat = self.T.deltat.default()

        if tmin is None:
            tmin = self.T.tmin.default()

        Object.__init__(self, codes=codes, data=data, tmin=tmin, deltat=deltat)

    @property
    def ntraces(self):
        '''
        Number of components (read-only, derived from :py:gattr:`data`).
        '''
        return self.data.shape[0]

    @property
    def nsamples(self):
        '''
        Number of samples per component (read-only, derived from
        :py:gattr:`data`).
        '''
        return self.data.shape[1]

    @property
    def summary_codes(self):
        '''
        Compact textual representation of the component codes.
        '''
        if not self.codes:
            return 'None'

        if len(self.codes) == 1:
            return str(self.codes[0])
        elif len(self.codes) == 2:
            return '%s, %s' % (self.codes[0], self.codes[-1])
        else:
            return '%s, ..., %s' % (self.codes[0], self.codes[-1])

    @property
    def summary_entries(self):
        '''
        Tuple of strings with the metadata shown in :py:attr:`summary`.
        '''
        return (
            self.__class__.__name__,
            str(self.data.shape[0]),
            str(self.data.shape[1]),
            str(self.data.dtype),
            str(self.deltat),
            util.time_to_str(self.tmin),
            util.time_to_str(self.tmax),
            self.summary_codes)

    @property
    def summary(self):
        '''
        Textual summary of the waveform's metadata attributes.
        '''
        return util.fmt_summary(
            self.summary_entries, (10, 5, 7, 10, 10, 25, 25, 50))

    def __len__(self):
        '''
        Get number of components.
        '''
        return self.ntraces

    def __getitem__(self, i):
        '''
        Get single component waveform (shared data).

        :param i:
            Component index.
        :type i:
            int
        '''
        return self.get_trace(i)

    def copy(self, data='copy'):
        '''
        Create a copy.

        :param data:
            ``'copy'`` to deeply copy the data, ``'reference'`` to create
            a shallow copy, referencing the original data, or a
            :py:class:`numpy.ma.MaskedArray` to be used as the data of the
            copy.
        :type data:
            str or :py:class:`numpy.ma.MaskedArray`

        :raises ValueError:
            If ``data`` is a string other than ``'copy'`` or ``'reference'``.
        :raises TypeError:
            If ``data`` is neither a string nor a masked array.
        '''
        if isinstance(data, str):
            if data not in ('copy', 'reference'):
                raise ValueError(
                    'MultiTrace.copy: data must be "copy" or "reference", '
                    'got %r.' % (data,))

            data = self.data.copy() if data == 'copy' else self.data

        elif not isinstance(data, ma.MaskedArray):
            raise TypeError(
                'MultiTrace.copy: data must be a string or a '
                'numpy.ma.MaskedArray.')

        return MultiTrace(
            data=data,
            codes=list(self.codes),
            tmin=self.tmin,
            deltat=self.deltat)

    @property
    def tmax(self):
        '''
        End time (time of last sample, read-only).
        '''
        return self.tmin + (self.nsamples - 1) * self.deltat

    def get_trace(self, i, span=slice(None)):
        '''
        Get single component waveform (shared data).

        The trace is built from the raw sample buffer (``data.data``), so the
        mask is not applied.

        :param i:
            Component index.
        :type i:
            int

        :param span:
            Range of samples to extract. Its ``start`` (``None`` meaning 0)
            determines the start time of the returned trace.
        :type span:
            slice
        '''
        network, station, location, channel, extra = self.codes[i]
        return Trace(
            network=network,
            station=station,
            location=location,
            channel=channel,
            extra=extra,
            tmin=self.tmin + (span.start or 0) * self.deltat,
            deltat=self.deltat,
            ydata=self.data.data[i, span])

    def iter_valid_traces(self):
        '''
        Iterate over the unmasked parts of all components.

        Yields one :py:class:`~pyrocko.trace.Trace` per contiguous run of
        unmasked samples in each component. If nothing is masked, one trace
        per component is yielded.
        '''
        if self.data.mask is ma.nomask:
            yield from self
            return

        for irow, row_spans in enumerate(
                ma.notmasked_contiguous(self.data, axis=1)):

            for span in row_spans:
                yield self.get_trace(irow, span)

    def get_traces(self):
        '''
        Get list of all components as single-component traces.
        '''
        return list(self)

    def get_valid_traces(self):
        '''
        Get list of traces for all unmasked parts of all components.
        '''
        return list(self.iter_valid_traces())

    def snuffle(self, what='valid'):
        '''
        Show in Snuffler.

        :param what:
            ``'valid'`` to show only the unmasked parts of the components,
            ``'raw'`` to show the complete components.
        :type what:
            str

        :raises ValueError:
            If ``what`` is not one of the choices above.
        '''
        if what not in ('valid', 'raw'):
            raise ValueError(
                'MultiTrace.snuffle: what must be "valid" or "raw", '
                'got %r.' % (what,))

        if what == 'valid':
            trace.snuffle(self.get_valid_traces())
        else:
            trace.snuffle(self.get_traces())

    def bleed(self, t):
        '''
        Extend masked regions by a time margin (in place).

        ``round(abs(t)/deltat)`` samples are masked at the start and at the
        end of every contiguous run of unmasked samples, and at both ends of
        the data array. Nothing is done if this count is smaller than one.

        :param t:
            Margin [s]. The sign is ignored.
        :type t:
            float
        '''
        nt = int(num.round(abs(t)/self.deltat))
        if nt < 1:
            return

        if self.data.mask is ma.nomask:
            self.data.mask = ma.make_mask_none(self.data.shape)

        for irow, row_spans in enumerate(
                ma.notmasked_contiguous(self.data, axis=1)):

            for span in row_spans:
                self.data.mask[irow, span.start:span.start+nt] = True
                self.data.mask[irow, max(0, span.stop-nt):span.stop] = True

        self.data.mask[:, :nt] = True
        self.data.mask[:, -nt:] = True

    def set_data(self, data):
        '''
        Replace the sample data.

        :param data:
            New samples with the same shape as the current data. A
            :py:class:`numpy.ma.MaskedArray` is used as-is, including its
            mask. Any other array gets a copy of the current mask attached.
        :type data:
            :py:class:`numpy.ndarray`

        :raises ValueError:
            If the shape differs from that of the current data.
        '''
        if data is self.data:
            return

        if data.shape != self.data.shape:
            raise ValueError(
                'MultiTrace.set_data: shape mismatch: %s != %s'
                % (data.shape, self.data.shape))

        if not isinstance(data, ma.MaskedArray):
            data = ma.MaskedArray(data)
            data.mask = self.data.mask

        self.data = data

    def apply(self, f):
        '''
        Transform the data with a function (see :py:meth:`set_data`).

        :param f:
            Callable taking the masked data array and returning an array of
            the same shape.
        :type f:
            callable
        '''
        self.set_data(f(self.data))

    def reduce(self, f, codes):
        '''
        Replace the components by the result of a reduction over them.

        :param f:
            Callable taking the data array and returning either a 1D array
            (a single output component) or a 2D array, in both cases with an
            unchanged number of samples.
        :type f:
            callable

        :param codes:
            Codes of the output components. A single
            :py:class:`~pyrocko.squirrel.model.CodesNSLCE` is accepted for a
            single output component.
        :type codes:
            :py:class:`list` of
            :py:class:`~pyrocko.squirrel.model.CodesNSLCE`

        :raises ValueError:
            If the result has an unexpected shape or does not match the
            number of codes given.
        '''
        data = f(self.data)
        if data.ndim == 1:
            data = data[num.newaxis, :]

        if isinstance(codes, CodesNSLCE):
            codes = [codes]

        if data.ndim != 2:
            raise ValueError(
                'MultiTrace.reduce: result must be a 1D or 2D array.')

        if data.shape[1] != self.nsamples:
            raise ValueError(
                'MultiTrace.reduce: number of samples must not change.')

        if len(codes) != data.shape[0]:
            raise ValueError(
                'MultiTrace.reduce: mismatch between number of output '
                'components and number of codes given.')

        # ntraces/nsamples are derived from data, so they follow this change.
        self.codes = codes
        if isinstance(data, ma.MaskedArray):
            self.data = data
        else:
            self.data = ma.MaskedArray(data)

    def nyquist_check(
            self,
            frequency,
            intro='Corner frequency',
            warn=True,
            raise_exception=False):

        '''
        Check if a given frequency is above the Nyquist frequency of the trace.

        :param frequency:
            Frequency to check [Hz].
        :type frequency:
            float

        :param intro:
            String used to introduce the warning/error message.
        :type intro:
            str

        :param warn:
            Whether to emit a warning message.
        :type warn:
            bool

        :param raise_exception:
            Whether to raise :py:exc:`~pyrocko.trace.AboveNyquist`.
        :type raise_exception:
            bool
        '''

        if frequency >= 0.5/self.deltat:
            message = '%s (%g Hz) is equal to or higher than nyquist ' \
                      'frequency (%g Hz). (%s)' \
                % (intro, frequency, 0.5/self.deltat, self.summary)
            if warn:
                logger.warning(message)
            if raise_exception:
                raise AboveNyquist(message)

    def lfilter(self, b, a, demean=True):
        '''
        Filter waveforms with :py:func:`scipy.signal.lfilter`.

        Sample data is converted to type :py:class:`float`, possibly demeaned
        and filtered using :py:func:`scipy.signal.lfilter`. Masked samples are
        set to zero before filtering. The mask itself is kept.

        :param b:
            Numerator coefficients.
        :type b:
            float

        :param a:
            Denominator coefficients.
        :type a:
            float

        :param demean:
            Subtract mean before filtering.
        :type demean:
            bool
        '''

        def filt(data):
            data = data.astype(num.float64)
            if demean:
                data -= num.mean(data, axis=1)[:, num.newaxis]

            # In-place arithmetic on a masked array leaves masked entries
            # untouched, so they may hold arbitrary (non-demeaned) values.
            # Zero them, as apply_via_fft does, to keep them out of the
            # recursive filter.
            return signal.lfilter(b, a, ma.filled(data, 0.0))

        self.apply(filt)

    def _butterworth(
            self, btype, intro, order, corner, nyquist_warn,
            nyquist_exception, demean):

        # Shared implementation of lowpass() and highpass(): design a
        # Butterworth filter of type btype ('low' or 'high') and apply it
        # with lfilter().
        self.nyquist_check(corner, intro, nyquist_warn, nyquist_exception)

        # corner*2*deltat is the corner relative to Nyquist (0.5/deltat).
        (b, a) = _get_cached_filter_coeffs(
            order, [corner*2.0*self.deltat], btype=btype)

        if len(a) != order+1 or len(b) != order+1:
            logger.warning(
                'Erroneous filter coefficients returned by '
                'scipy.signal.butter(). Should downsample before filtering.')

        self.lfilter(b, a, demean=demean)

    def lowpass(self, order, corner, nyquist_warn=True,
                nyquist_exception=False, demean=True):

        '''
        Filter waveforms using a Butterworth lowpass.

        Sample data is converted to type :py:class:`float`, possibly demeaned
        and filtered using :py:func:`scipy.signal.lfilter`. Filter coefficients
        are generated with :py:func:`scipy.signal.butter`.

        :param order:
            Order of the filter.
        :type order:
            int

        :param corner:
            Corner frequency of the filter [Hz].
        :type corner:
            float

        :param demean:
            Subtract mean before filtering.
        :type demean:
            bool

        :param nyquist_warn:
            Warn if corner frequency is greater than Nyquist frequency.
        :type nyquist_warn:
            bool

        :param nyquist_exception:
            Raise :py:exc:`pyrocko.trace.AboveNyquist` if corner frequency is
            greater than Nyquist frequency.
        :type nyquist_exception:
            bool
        '''
        self._butterworth(
            'low', 'Corner frequency of lowpass', order, corner,
            nyquist_warn, nyquist_exception, demean)

    def highpass(self, order, corner, nyquist_warn=True,
                 nyquist_exception=False, demean=True):

        '''
        Filter waveforms using a Butterworth highpass.

        Sample data is converted to type :py:class:`float`, possibly demeaned
        and filtered using :py:func:`scipy.signal.lfilter`. Filter coefficients
        are generated with :py:func:`scipy.signal.butter`.

        :param order:
            Order of the filter.
        :type order:
            int

        :param corner:
            Corner frequency of the filter [Hz].
        :type corner:
            float

        :param demean:
            Subtract mean before filtering.
        :type demean:
            bool

        :param nyquist_warn:
            Warn if corner frequency is greater than Nyquist frequency.
        :type nyquist_warn:
            bool

        :param nyquist_exception:
            Raise :py:exc:`~pyrocko.trace.AboveNyquist` if corner frequency is
            greater than Nyquist frequency.
        :type nyquist_exception:
            bool
        '''
        self._butterworth(
            'high', 'Corner frequency of highpass', order, corner,
            nyquist_warn, nyquist_exception, demean)

    def smooth(self, t, window=num.hanning):
        '''
        Smooth waveforms by convolution with a window function.

        The convolution is done in the frequency domain (see
        :py:meth:`apply_via_fft`). Masked samples are treated as zeros. The
        window is not normalized.

        :param t:
            Approximate window length [s]. It is rounded to an odd number of
            samples.
        :type t:
            float

        :param window:
            Function returning a window of a given number of samples, e.g.
            :py:func:`numpy.hanning`.
        :type window:
            callable
        '''
        n = (int(num.round(t / self.deltat)) // 2) * 2 + 1
        nhalf = n // 2
        taper = window(n)

        def multiply_taper(df, ntrans, spec):
            # Center the window on sample 0 and wrap its left half to the end
            # of the buffer, so the convolution introduces no time shift.
            # Indexing from ntrans - nhalf also works for nhalf == 0.
            taper_pad = num.zeros(ntrans)
            taper_pad[:nhalf+1] = taper[nhalf:]
            taper_pad[ntrans-nhalf:] = taper[:nhalf]
            taper_fd = num.fft.rfft(taper_pad)
            spec *= taper_fd[num.newaxis, :]
            return spec

        # At least half a window of zero padding keeps the circular
        # convolution from wrapping the end of the signal onto its start.
        self.apply_via_fft(
            multiply_taper,
            ntrans_min=max(n, self.nsamples + nhalf))

    def apply_via_fft(self, f, ntrans_min=0):
        '''
        Apply a frequency-domain operation to all components.

        Samples are converted to float, masked samples are set to zero, and
        the data is zero-padded to ``trace.nextpow2`` of the required length
        before transforming with :py:func:`numpy.fft.rfft`. The result is
        cropped to the original number of samples; the mask is kept.

        :param f:
            Callable ``f(df, ntrans, spec)`` returning the modified spectra,
            where ``spec`` has shape ``(ntraces, ntrans//2+1)`` and ``df`` is
            the frequency spacing [Hz].
        :type f:
            callable

        :param ntrans_min:
            Minimum transform length in samples.
        :type ntrans_min:
            int
        '''
        ntrans = trace.nextpow2(max(ntrans_min, self.nsamples))
        data = ma.filled(self.data.astype(num.float64), 0.0)
        spec = num.fft.rfft(data, ntrans)
        df = 1.0 / (self.deltat * ntrans)
        spec = f(df, ntrans, spec)
        # Pass the length explicitly: the irfft default of 2*(m-1) output
        # samples is wrong for an odd ntrans.
        data2 = num.fft.irfft(spec, ntrans)[:, :self.nsamples]
        self.set_data(data2)

    def get_energy(
            self,
            grouping=SensorGrouping(),
            translation=ReplaceComponentTranslation(),
            postprocessing=None):

        '''
        Get the energy (sum of squared samples) of groups of components.

        Components are grouped by ``grouping.key(codes)``. Within each group,
        the squared samples are summed sample by sample. The codes of each
        output component are derived from those of the group's first
        component via ``translation``, formatted with ``component='G'``.

        :param grouping:
            Defines which components belong together.

        :param translation:
            Maps input codes to output codes.

        :param postprocessing:
            Optional callable applied to the result (see :py:meth:`apply`).
        :type postprocessing:
            callable

        :returns:
            New :py:class:`MultiTrace` with one component per group, in order
            of first appearance.
        '''
        groups = {}
        for irow, codes in enumerate(self.codes):
            groups.setdefault(grouping.key(codes), []).append(irow)

        data = self.data.astype(num.float64)
        data **= 2
        data3 = num.ma.empty((len(groups), self.nsamples))
        codes = []
        for irow_out, irows_in in enumerate(groups.values()):
            data3[irow_out, :] = data[irows_in, :].sum(axis=0)
            codes.append(CodesNSLCE(
                translation.translate(
                    self.codes[irows_in[0]]).safe_str.format(component='G')))

        energy = MultiTrace(
            data=data3,
            codes=codes,
            tmin=self.tmin,
            deltat=self.deltat)

        if postprocessing is not None:
            energy.apply(postprocessing)

        return energy

    # RMS: square root of the energy, computed in place.
    get_rms = partialmethod(
        get_energy,
        postprocessing=lambda data: num.sqrt(data, out=data))

    # Natural log of the RMS: 0.5*log(E) == log(sqrt(E)). The energy is first
    # passed through a zero-phase 2-tap average ([0.5, 0.5] via filtfilt).
    get_log_rms = partialmethod(
        get_energy,
        postprocessing=lambda data: num.multiply(
            num.log(
                signal.filtfilt([0.5, 0.5], [1], data),
                out=data),
            0.5,
            out=data))

    # Same as get_log_rms, but base-10 logarithm.
    get_log10_rms = partialmethod(
        get_energy,
        postprocessing=lambda data: num.multiply(
            num.log(
                signal.filtfilt([0.5, 0.5], [1], data),
                out=data),
            0.5 / num.log(10.0),
            out=data))
def correlate(a, b, mode='valid', normalization=None, use_fft=False):
    '''
    Cross-correlate traces and/or multi-traces.

    Two :py:class:`~pyrocko.trace.Trace` arguments are correlated with
    :py:func:`pyrocko.trace.correlate` and a single trace is returned. If at
    least one argument is a :py:class:`MultiTrace`, every component of ``a``
    is correlated with every component of ``b`` (``a``-major order). The
    results are returned as a new :py:class:`MultiTrace`.

    :param a:
        First input.
    :type a:
        :py:class:`~pyrocko.trace.Trace` or :py:class:`MultiTrace`

    :param b:
        Second input.
    :type b:
        :py:class:`~pyrocko.trace.Trace` or :py:class:`MultiTrace`

    ``mode``, ``normalization`` and ``use_fft`` are passed on to
    :py:func:`pyrocko.trace.correlate`.

    :raises TypeError:
        If an argument is neither a trace nor a multi-trace.
    '''

    def corr(a_, b_):
        return trace.correlate(
            a_, b_, mode=mode, normalization=normalization, use_fft=use_fft)

    if isinstance(a, Trace) and isinstance(b, Trace):
        return corr(a, b)

    supported = (Trace, MultiTrace)
    if not (isinstance(a, supported) and isinstance(b, supported)):
        # Previously such calls silently returned None.
        raise TypeError(
            'correlate: arguments must be Trace or MultiTrace objects, '
            'got %s and %s.' % (type(a).__name__, type(b).__name__))

    # Treat a single trace like a one-component multi-trace.
    a_components = [a] if isinstance(a, Trace) else a
    b_components = [b] if isinstance(b, Trace) else b

    return MultiTrace([
        corr(a_, b_)
        for a_ in a_components
        for b_ in b_components])