Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/multitrace.py: 34%

188 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-02-27 10:58 +0000

1# http://pyrocko.org - GPLv3 

2# 

3# The Pyrocko Developers, 21st Century 

4# ---|P------/S----------~Lg---------- 

5 

6''' 

7Multi-component waveform data model. 

8''' 

9 

10import logging 

11from functools import partialmethod 

12 

13import numpy as num 

14import numpy.ma as ma 

15from scipy import signal 

16 

17from . import trace, util 

18from .trace import Trace, AboveNyquist, _get_cached_filter_coeffs 

19from .guts import Object, Float, Timestamp, List 

20from .guts_array import Array 

21from .squirrel import \ 

22 CodesNSLCE, SensorGrouping 

23 

24from .squirrel.operators.base import ReplaceComponentTranslation 

25 

# Module-level logger used for warnings emitted by MultiTrace methods.
logger = logging.getLogger(name='pyrocko.multitrace')

27 

28 

class MultiTrace(Object):
    '''
    Container for multi-component waveforms with common time span and sampling.

    Instances of this class can be used to efficiently represent
    multi-component waveforms of a single sensor or of a sensor array. The data
    samples are stored in a single 2D array where the first index runs over
    components and the second index over time. Metadata contains sampling rate,
    start-time and :py:class:`~pyrocko.squirrel.model.CodesNSLCE` identifiers
    for the contained traces.

    The :py:gattr:`data` is held as a NumPy :py:class:`numpy.ma.MaskedArray`
    where missing or invalid data is masked.

    :param traces:
        If given, construct multi-trace from given single-component waveforms
        (see :py:func:`~pyrocko.trace.get_traces_data_as_array`) and ignore
        any other arguments.
    :type traces:
        :py:class:`list` of :py:class:`~pyrocko.trace.Trace`

    :param assemble:
        How ``traces`` are combined: ``'concatenate'`` (default) stacks the
        traces' samples with
        :py:func:`~pyrocko.trace.get_traces_data_as_array` and takes
        ``tmin`` and ``deltat`` from the first trace; ``'merge'`` uses
        :py:func:`~pyrocko.trace.merge_traces_data_as_array`.
    :type assemble:
        str

    :param data:
        Sample array indexed as ``(icomponent, isample)``, used when
        ``traces`` is not given. Defaults to an empty ``(0, 0)`` array.
    :type data:
        :py:class:`numpy.ma.MaskedArray`

    :param codes:
        One codes entry per component. Defaults to empty
        :py:class:`~pyrocko.squirrel.model.CodesNSLCE` entries.
    :type codes:
        :py:class:`list` of :py:class:`~pyrocko.squirrel.model.CodesNSLCE`

    :param tmin:
        Start time. Defaults to the default of the ``tmin`` field.
    :type tmin:
        float

    :param deltat:
        Sampling interval [s]. Defaults to the default of the ``deltat``
        field.
    :type deltat:
        float

    :raises ValueError:
        If ``assemble`` is neither ``'merge'`` nor ``'concatenate'``, or if
        the number of codes does not match the number of components.
    '''

    codes = List.T(
        CodesNSLCE.T(),
        help='List of codes identifying the components.')
    data = Array.T(
        shape=(None, None),
        help='Array containing the data samples indexed as '
             '``(icomponent, isample)``.')
    tmin = Timestamp.T(
        default=Timestamp.D('1970-01-01 00:00:00'),
        help='Start time.')
    deltat = Float.T(
        default=1.0,
        help='Sampling interval [s]')

    def __init__(
            self,
            traces=None,
            assemble='concatenate',
            data=None,
            codes=None,
            tmin=None,
            deltat=None):

        if traces is not None:
            if len(traces) == 0:
                data = ma.zeros((0, 0))
            elif assemble == 'merge':
                data, codes, tmin, deltat \
                    = trace.merge_traces_data_as_array(traces)

            elif assemble == 'concatenate':
                data = ma.array(trace.get_traces_data_as_array(traces))
                codes = [tr.codes for tr in traces]
                tmin = traces[0].tmin
                deltat = traces[0].deltat

            else:
                raise ValueError(
                    'MultiTrace construction: invalid value for assemble: '
                    '%r (expected "merge" or "concatenate").' % (assemble,))

        if data is None:
            # Neither traces nor data given: start out with an empty array
            # instead of failing on ``None.shape`` below.
            data = ma.zeros((0, 0))

        # Unpacking also enforces that data is two-dimensional.
        ntraces, _ = data.shape

        if codes is None:
            codes = [CodesNSLCE()] * ntraces

        if len(codes) != ntraces:
            raise ValueError(
                'MultiTrace construction: mismatch between number of traces '
                'and number of codes given.')

        if deltat is None:
            deltat = self.T.deltat.default()

        if tmin is None:
            tmin = self.T.tmin.default()

        Object.__init__(self, codes=codes, data=data, tmin=tmin, deltat=deltat)

    @property
    def ntraces(self):
        '''
        Number of components (read-only, derived from :py:gattr:`data`).
        '''
        # Derived on access so it stays correct after reduce()/set_data().
        return self.data.shape[0]

    @property
    def nsamples(self):
        '''
        Number of samples per component (read-only, derived from
        :py:gattr:`data`).
        '''
        return self.data.shape[1]

    @property
    def summary_codes(self):
        '''
        Short textual representation of the contained codes: the single
        entry, both entries, or first and last entry separated by ``...``.
        '''
        if self.codes:
            if len(self.codes) == 1:
                return str(self.codes[0])
            elif len(self.codes) == 2:
                return '%s, %s' % (self.codes[0], self.codes[-1])
            else:
                return '%s, ..., %s' % (self.codes[0], self.codes[-1])
        else:
            return 'None'

    @property
    def summary_entries(self):
        '''
        Tuple of strings used to build :py:attr:`summary`: class name,
        number of components, number of samples, dtype, sampling interval,
        start time, end time and codes summary.
        '''
        return (
            self.__class__.__name__,
            str(self.data.shape[0]),
            str(self.data.shape[1]),
            str(self.data.dtype),
            str(self.deltat),
            util.time_to_str(self.tmin),
            util.time_to_str(self.tmax),
            self.summary_codes)

    @property
    def summary(self):
        '''
        Textual summary of the waveform's metadata attributes.
        '''
        return util.fmt_summary(
            self.summary_entries, (10, 5, 7, 10, 10, 25, 25, 50))

    def __len__(self):
        '''
        Get number of components.
        '''
        return self.ntraces

    def __getitem__(self, i):
        '''
        Get single component waveform (shared data).

        :param i:
            Component index.
        :type i:
            int
        '''
        return self.get_trace(i)

    def copy(self, data='copy'):
        '''
        Create a copy

        :param data:
            ``'copy'`` to deeply copy the data, or ``'reference'`` to create
            a shallow copy, referencing the original data. Alternatively, a
            :py:class:`numpy.ma.MaskedArray` to be used as the data of the
            new object.
        :type data:
            str or :py:class:`numpy.ma.MaskedArray`
        '''

        if isinstance(data, str):
            assert data in ('copy', 'reference')
            data = self.data.copy() if data == 'copy' else self.data
        else:
            assert isinstance(data, ma.MaskedArray)

        return MultiTrace(
            data=data,
            codes=list(self.codes),
            tmin=self.tmin,
            deltat=self.deltat)

    @property
    def tmax(self):
        '''
        End time (time of last sample, read-only).
        '''
        return self.tmin + (self.nsamples - 1) * self.deltat

    def get_trace(self, i, span=slice(None)):
        '''
        Get single component waveform (shared data).

        The returned trace's samples are a slice of the raw (unmasked)
        sample array ``data.data``; the mask is not applied.

        :param i:
            Component index.
        :type i:
            int

        :param span:
            Sample range to extract. The trace's ``tmin`` is shifted by
            ``span.start`` samples.
        :type span:
            :py:class:`slice`
        '''

        network, station, location, channel, extra = self.codes[i]
        return Trace(
            network=network,
            station=station,
            location=location,
            channel=channel,
            extra=extra,
            tmin=self.tmin + (span.start or 0) * self.deltat,
            deltat=self.deltat,
            ydata=self.data.data[i, span])

    def iter_valid_traces(self):
        '''
        Iterate over contiguous unmasked pieces of all components.

        If no mask is set, every component is yielded as a whole. Otherwise
        one :py:class:`~pyrocko.trace.Trace` is yielded per contiguous run
        of unmasked samples in each component.
        '''
        if self.data.mask is ma.nomask:
            yield from self
        else:
            for irow, row in enumerate(
                    ma.notmasked_contiguous(self.data, axis=1)):
                for span in row:
                    yield self.get_trace(irow, span)

    def get_traces(self):
        '''
        Get all components as a list of traces (shared data).
        '''
        return list(self)

    def get_valid_traces(self):
        '''
        Get contiguous unmasked pieces as a list of traces (see
        :py:meth:`iter_valid_traces`).
        '''
        return list(self.iter_valid_traces())

    def snuffle(self, what='valid'):
        '''
        Show in Snuffler.

        :param what:
            ``'valid'`` to show only unmasked pieces, ``'raw'`` to show all
            components including masked samples.
        :type what:
            str
        '''

        assert what in ('valid', 'raw')

        if what == 'valid':
            trace.snuffle(self.get_valid_traces())
        else:
            trace.snuffle(list(self))

    def bleed(self, t):
        '''
        Grow masked regions by a given duration (in place).

        ``nt = round(abs(t) / deltat)`` samples are masked at both ends of
        every contiguous unmasked run, and at both ends of the whole array.
        Nothing is done if ``nt`` is smaller than one.

        :param t:
            Duration [s] by which masked regions are extended.
        :type t:
            float
        '''

        nt = int(num.round(abs(t)/self.deltat))
        if nt < 1:
            return

        if self.data.mask is ma.nomask:
            # Materialize a full boolean mask so it can be assigned to.
            self.data.mask = ma.make_mask_none(self.data.shape)

        for irow, row in enumerate(ma.notmasked_contiguous(self.data, axis=1)):
            for span in row:
                self.data.mask[irow, span.start:span.start+nt] = True
                self.data.mask[irow, max(0, span.stop-nt):span.stop] = True

        self.data.mask[:, :nt] = True
        self.data.mask[:, -nt:] = True

    def set_data(self, data):
        '''
        Replace the sample array (in place).

        :param data:
            New samples with the same shape as the current data. A plain
            array is wrapped into a masked array carrying the current mask.
        :type data:
            :py:class:`numpy.ndarray` or :py:class:`numpy.ma.MaskedArray`
        '''
        if data is self.data:
            return

        assert data.shape == self.data.shape

        if isinstance(data, ma.MaskedArray):
            self.data = data
        else:
            data = ma.MaskedArray(data)
            data.mask = self.data.mask
            self.data = data

    def apply(self, f):
        '''
        Replace the data by ``f(data)`` (in place, see :py:meth:`set_data`).

        :param f:
            Function taking and returning an array of the same shape.
        :type f:
            callable
        '''
        self.set_data(f(self.data))

    def reduce(self, f, codes):
        '''
        Replace the data by ``f(data)``, which may change the number of
        components (in place).

        :param f:
            Function taking the data and returning a 1D or 2D array with the
            same number of samples. A 1D result becomes a single component.
        :type f:
            callable

        :param codes:
            Codes for the resulting components: a single
            :py:class:`~pyrocko.squirrel.model.CodesNSLCE` or a list with
            one entry per resulting component.
        '''
        data = f(self.data)
        if data.ndim == 1:
            data = data[num.newaxis, :]
        if isinstance(codes, CodesNSLCE):
            codes = [codes]
        assert data.ndim == 2
        assert data.shape[1] == self.data.shape[1]
        assert len(codes) == data.shape[0]
        self.codes = codes
        if isinstance(data, ma.MaskedArray):
            self.data = data
        else:
            self.data = ma.MaskedArray(data)

    def nyquist_check(
            self,
            frequency,
            intro='Corner frequency',
            warn=True,
            raise_exception=False):

        '''
        Check if a given frequency is above the Nyquist frequency of the trace.

        :param frequency:
            Frequency [Hz] to check against ``0.5 / deltat``.
        :type frequency:
            float

        :param intro:
            String used to introduce the warning/error message.
        :type intro:
            str

        :param warn:
            Whether to emit a warning message.
        :type warn:
            bool

        :param raise_exception:
            Whether to raise :py:exc:`~pyrocko.trace.AboveNyquist`.
        :type raise_exception:
            bool
        '''

        if frequency >= 0.5/self.deltat:
            message = '%s (%g Hz) is equal to or higher than nyquist ' \
                      'frequency (%g Hz). (%s)' \
                % (intro, frequency, 0.5/self.deltat, self.summary)
            if warn:
                logger.warning(message)
            if raise_exception:
                raise AboveNyquist(message)

    def lfilter(self, b, a, demean=True):
        '''
        Filter waveforms with :py:func:`scipy.signal.lfilter`.

        Sample data is converted to type :py:class:`float`, possibly demeaned
        and filtered using :py:func:`scipy.signal.lfilter`.

        :param b:
            Numerator coefficients.
        :type b:
            array-like of float

        :param a:
            Denominator coefficients.
        :type a:
            array-like of float

        :param demean:
            Subtract mean before filtering.
        :type demean:
            bool
        '''

        def filt(data):
            data = data.astype(num.float64)
            if demean:
                data -= num.mean(data, axis=1)[:, num.newaxis]

            return signal.lfilter(b, a, data)

        self.apply(filt)

    def lowpass(self, order, corner, nyquist_warn=True,
                nyquist_exception=False, demean=True):

        '''
        Filter waveforms using a Butterworth lowpass.

        Sample data is converted to type :py:class:`float`, possibly demeaned
        and filtered using :py:func:`scipy.signal.lfilter`. Filter coefficients
        are generated with :py:func:`scipy.signal.butter`.

        :param order:
            Order of the filter.
        :type order:
            int

        :param corner:
            Corner frequency of the filter [Hz].
        :type corner:
            float

        :param demean:
            Subtract mean before filtering.
        :type demean:
            bool

        :param nyquist_warn:
            Warn if corner frequency is greater than Nyquist frequency.
        :type nyquist_warn:
            bool

        :param nyquist_exception:
            Raise :py:exc:`~pyrocko.trace.AboveNyquist` if corner frequency is
            greater than Nyquist frequency.
        :type nyquist_exception:
            bool
        '''

        self.nyquist_check(
            corner, 'Corner frequency of lowpass', nyquist_warn,
            nyquist_exception)

        (b, a) = _get_cached_filter_coeffs(
            order, [corner*2.0*self.deltat], btype='low')

        if len(a) != order+1 or len(b) != order+1:
            logger.warning(
                'Erroneous filter coefficients returned by '
                'scipy.signal.butter(). Should downsample before filtering.')

        self.lfilter(b, a, demean=demean)

    def highpass(self, order, corner, nyquist_warn=True,
                 nyquist_exception=False, demean=True):

        '''
        Filter waveforms using a Butterworth highpass.

        Sample data is converted to type :py:class:`float`, possibly demeaned
        and filtered using :py:func:`scipy.signal.lfilter`. Filter coefficients
        are generated with :py:func:`scipy.signal.butter`.

        :param order:
            Order of the filter.
        :type order:
            int

        :param corner:
            Corner frequency of the filter [Hz].
        :type corner:
            float

        :param demean:
            Subtract mean before filtering.
        :type demean:
            bool

        :param nyquist_warn:
            Warn if corner frequency is greater than Nyquist frequency.
        :type nyquist_warn:
            bool

        :param nyquist_exception:
            Raise :py:exc:`~pyrocko.trace.AboveNyquist` if corner frequency is
            greater than Nyquist frequency.
        :type nyquist_exception:
            bool
        '''

        self.nyquist_check(
            corner, 'Corner frequency of highpass', nyquist_warn,
            nyquist_exception)

        (b, a) = _get_cached_filter_coeffs(
            order, [corner*2.0*self.deltat], btype='high')

        if len(a) != order+1 or len(b) != order+1:
            logger.warning(
                'Erroneous filter coefficients returned by '
                'scipy.signal.butter(). Should downsample before filtering.')

        self.lfilter(b, a, demean=demean)

    def smooth(self, t, window=num.hanning):
        '''
        Smooth waveforms by convolution with a window function (in place).

        The window length ``n`` is ``round(t / deltat)`` rounded down to an
        even number plus one, so it is always odd and centred on each
        sample. The convolution is done in the frequency domain with
        :py:meth:`apply_via_fft`; masked samples enter as zeros. The window
        is not normalized, so the output is scaled by the sum of the window
        samples.

        :param t:
            Approximate window duration [s].
        :type t:
            float

        :param window:
            Window function, called as ``window(n)`` to obtain ``n`` window
            samples, e.g. :py:func:`numpy.hanning` (default).
        :type window:
            callable
        '''
        n = (int(num.round(t / self.deltat)) // 2) * 2 + 1
        half = n // 2
        taper = window(n)

        def multiply_taper(df, ntrans, spec):
            # Zero-phase layout: centre sample at index 0, left half wrapped
            # to the end of the buffer. The explicit positive start index
            # avoids ``taper_pad[-0:]`` (the whole buffer) when n == 1.
            taper_pad = num.zeros(ntrans)
            taper_pad[:half+1] = taper[half:]
            taper_pad[ntrans-half:] = taper[:half]
            taper_fd = num.fft.rfft(taper_pad)
            spec *= taper_fd[num.newaxis, :]
            return spec

        # Pad by at least the window length so that the circular FFT
        # convolution does not wrap samples from one end into the other.
        self.apply_via_fft(
            multiply_taper,
            ntrans_min=self.nsamples + n)

    def apply_via_fft(self, f, ntrans_min=0):
        '''
        Apply an operation in the frequency domain to all components
        (in place).

        Masked samples are replaced by zeros, the data is converted to
        ``float64`` and transformed with :py:func:`numpy.fft.rfft` using a
        transform length ``ntrans = trace.nextpow2(max(ntrans_min,
        nsamples))``. The result of ``f`` is transformed back, cropped to
        the original number of samples and stored with :py:meth:`set_data`,
        which keeps the current mask.

        :param f:
            Called as ``f(df, ntrans, spec)`` with the frequency spacing
            ``df`` [Hz], the transform length and the 2D spectrum indexed as
            ``(icomponent, ifrequency)``; must return the modified spectrum.
        :type f:
            callable

        :param ntrans_min:
            Minimum transform length.
        :type ntrans_min:
            int
        '''
        ntrans = trace.nextpow2(max(ntrans_min, self.nsamples))
        data = ma.filled(self.data.astype(num.float64), 0.0)
        spec = num.fft.rfft(data, ntrans)
        df = 1.0 / (self.deltat * ntrans)
        spec = f(df, ntrans, spec)
        # Pass ntrans explicitly: irfft's default output length is
        # 2 * (nfrequencies - 1), which differs from ntrans if it is odd.
        data2 = num.fft.irfft(spec, ntrans)[:, :self.nsamples]
        self.set_data(data2)

    def get_energy(
            self,
            grouping=SensorGrouping(),
            translation=ReplaceComponentTranslation(),
            postprocessing=None):

        '''
        Get squared amplitudes summed over groups of components.

        Components are grouped by ``grouping.key(codes)``. For each group,
        the squared samples of all member components are summed, giving one
        output component per group, in order of first appearance. The output
        codes are built from the first member's codes via
        ``translation.translate(...)``, formatted with ``component='G'``.

        :param grouping:
            Object providing ``key(codes)`` to assign components to groups.

        :param translation:
            Object providing ``translate(codes)`` used to derive the output
            codes.

        :param postprocessing:
            Optional function applied to the result with :py:meth:`apply`.
        :type postprocessing:
            callable

        :returns:
            New :py:class:`MultiTrace` with one component per group.
        '''

        groups = {}
        for irow, codes in enumerate(self.codes):
            groups.setdefault(grouping.key(codes), []).append(irow)

        data = self.data.astype(num.float64)
        data **= 2
        data3 = num.ma.empty((len(groups), self.nsamples))
        codes = []
        for irow_out, irows_in in enumerate(groups.values()):
            data3[irow_out, :] = data[irows_in, :].sum(axis=0)
            codes.append(CodesNSLCE(
                translation.translate(
                    self.codes[irows_in[0]]).safe_str.format(component='G')))

        energy = MultiTrace(
            data=data3,
            codes=codes,
            tmin=self.tmin,
            deltat=self.deltat)

        if postprocessing is not None:
            energy.apply(postprocessing)

        return energy

    # Square root of the per-group energy from get_energy.
    get_rms = partialmethod(
        get_energy,
        postprocessing=lambda data: num.sqrt(data, out=data))

    # Natural logarithm of the square root of the per-group energy, after
    # smoothing the energy with a two-point average applied forward and
    # backward (scipy.signal.filtfilt).
    get_log_rms = partialmethod(
        get_energy,
        postprocessing=lambda data: num.multiply(
            num.log(
                signal.filtfilt([0.5, 0.5], [1], data),
                out=data),
            0.5,
            out=data))

    # Same as get_log_rms but with base-10 logarithm.
    get_log10_rms = partialmethod(
        get_energy,
        postprocessing=lambda data: num.multiply(
            num.log(
                signal.filtfilt([0.5, 0.5], [1], data),
                out=data),
            0.5 / num.log(10.0),
            out=data))

531 

532 

def correlate(a, b, mode='valid', normalization=None, use_fft=False):
    '''
    Cross-correlate single- and/or multi-component waveforms.

    Dispatches on the argument types:

    * :py:class:`~pyrocko.trace.Trace` and
      :py:class:`~pyrocko.trace.Trace`: the result of
      :py:func:`pyrocko.trace.correlate` is returned unchanged.
    * :py:class:`~pyrocko.trace.Trace` and :py:class:`MultiTrace` (either
      order): the single trace is correlated with every component of the
      multi-trace; the results are returned as a :py:class:`MultiTrace`.
    * :py:class:`MultiTrace` and :py:class:`MultiTrace`: every component of
      ``a`` is correlated with every component of ``b`` (outer loop over
      ``a``); the results are returned as a :py:class:`MultiTrace`.

    ``mode``, ``normalization`` and ``use_fft`` are passed through to
    :py:func:`pyrocko.trace.correlate`.

    :raises TypeError:
        If the arguments are not a combination of
        :py:class:`~pyrocko.trace.Trace` and :py:class:`MultiTrace`.
    '''

    def corr(x, y):
        return trace.correlate(
            x, y, mode=mode, normalization=normalization, use_fft=use_fft)

    if isinstance(a, Trace) and isinstance(b, Trace):
        return corr(a, b)

    elif isinstance(a, Trace) and isinstance(b, MultiTrace):
        return MultiTrace([corr(a, b_) for b_ in b])

    elif isinstance(a, MultiTrace) and isinstance(b, Trace):
        return MultiTrace([corr(a_, b) for a_ in a])

    elif isinstance(a, MultiTrace) and isinstance(b, MultiTrace):
        return MultiTrace([corr(a_, b_) for a_ in a for b_ in b])

    # Previously fell through and returned None silently.
    raise TypeError(
        'correlate: arguments must be Trace or MultiTrace objects, got '
        '%s and %s.' % (type(a).__name__, type(b).__name__))