Coverage for /usr/local/lib/python3.11/dist-packages/pyrocko/response.py: 70%
447 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-06 06:59 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-06 06:59 +0000
1# http://pyrocko.org - GPLv3
2#
3# The Pyrocko Developers, 21st Century
4# ---|P------/S----------~Lg----------
6'''
7Frequency response parameterizations useful as transfer functions in signal
8processing.
9'''
11import math
12import logging
13import uuid
15import numpy as num
16from scipy import signal
18from pyrocko import util, evalresp
19from pyrocko.guts import Object, Float, Int, String, Complex, Tuple, List, \
20 StringChoice, Bool
21from pyrocko.guts_array import Array
# Namespace prefix handed to pyrocko.guts for the objects defined in this
# module. NOTE(review): presumably used for the type tags of serialized
# objects -- confirm against pyrocko.guts.
guts_prefix = 'pf'

# Module logger; InvalidResponse.evaluate emits its one-time warning here.
logger = logging.getLogger('pyrocko.response')
def asarray_1d(x, dtype):
    '''
    Convert ``x`` into a one-dimensional :py:class:`numpy.ndarray`.

    A non-empty list or tuple whose first element is a string is converted
    item by item through ``dtype`` (so ``['1', '2']`` works with
    ``dtype=float``). Any other input goes through :py:func:`numpy.asarray`
    and must come out one-dimensional.

    :param x: input sequence or array-like
    :param dtype: target dtype, also used as converter for string items
    :raises ValueError: if the generic conversion does not give a 1D array
    '''
    from_strings = (
        isinstance(x, (list, tuple))
        and len(x) > 0
        and isinstance(x[0], str))

    if from_strings:
        return num.asarray([dtype(item) for item in x], dtype=dtype)

    arr = num.asarray(x, dtype=dtype)
    if arr.ndim != 1:
        raise ValueError('Could not convert to 1D array.')

    return arr
def finalize_construction(breakpoints):
    '''
    Merge and clean up a collection of ``(frequency, change)`` breakpoints.

    The pairs are sorted. Changes at identical frequencies are summed, and
    entries whose net change is zero are dropped. In this module the
    ``construction()`` methods produce ``+1`` per zero and ``-1`` per pole.

    :param breakpoints: iterable of ``(f, c)`` pairs; it is NOT modified
        (the previous implementation sorted the caller's list in place).
    :returns: list of ``(f, c)`` tuples sorted by ``f``, with unique ``f``
        and non-zero ``c``.
    '''
    merged = []
    for f, c in sorted(breakpoints):
        if merged and merged[-1][0] == f:
            # Same frequency as the previous entry: accumulate the change.
            merged[-1][1] += c
        else:
            merged.append([f, c])

    return [(f, c) for (f, c) in merged if c != 0]
class FrequencyResponseCheckpoint(Object):
    '''
    Reference value of a frequency response at a single frequency.

    Instances are collected in :py:attr:`FrequencyResponse.checkpoints`.
    '''

    # Frequency of the checkpoint. NOTE(review): presumably [Hz], like the
    # frequencies passed to FrequencyResponse.evaluate -- confirm.
    frequency = Float.T()

    # Expected response value at ``frequency``. NOTE(review): declared as a
    # real Float, not Complex -- presumably an amplitude; confirm.
    value = Float.T()
class IsNotScalar(Exception):
    '''
    Raised by ``get_scalar()`` when a response is not flat, i.e. cannot be
    represented by a single constant factor.
    '''
def str_fmax_failsafe(resp):
    '''
    Format the maximum frequency of ``resp`` for use in summaries.

    :returns: ``resp.get_fmax()`` formatted with ``'%g'``, or ``'?'`` when
        :py:meth:`get_fmax` raises :py:exc:`InvalidResponseError` (e.g. a
        digital response whose sampling interval is zero).
    '''
    try:
        fmax = resp.get_fmax()
    except InvalidResponseError:
        return '?'

    return '%g' % fmax
class FrequencyResponse(Object):
    '''
    Base class for parameterized frequency responses.

    Used directly, it represents the flat unit response ``T(f) = 1``.
    Subclasses override :py:meth:`evaluate` and, where meaningful,
    :py:meth:`is_scalar`, :py:meth:`get_scalar`, :py:meth:`get_fmax`,
    :py:meth:`construction` and :py:attr:`summary`.
    '''

    checkpoints = List.T(
        FrequencyResponseCheckpoint.T())

    def __init__(self, *args, **kwargs):
        Object.__init__(self, *args, **kwargs)
        # Identifier unique to this instance. NOTE(review): set as a plain
        # attribute rather than a declared guts property -- presumably it is
        # therefore not serialized; confirm against pyrocko.guts.
        self.uuid = uuid.uuid4()

    def evaluate(self, freqs):
        '''
        Evaluate the response at given frequencies.

        :param freqs:
            Frequencies [Hz].
        :type freqs:
            :py:class:`numpy.ndarray` of shape ``(N,)`` and dtype
            :py:class:`float`

        :returns:
            Complex coefficients of the response.
        :rtype:
            :py:class:`numpy.ndarray` of shape ``(N,)`` and dtype
            :py:class:`complex`
        '''
        return num.ones(freqs.size, dtype=complex)

    def evaluate1(self, freq):
        '''
        Evaluate the response at a single frequency.

        :param freq:
            Frequency [Hz].
        :type freq:
            float

        :returns:
            Complex response coefficient.
        :rtype:
            complex
        '''
        return self.evaluate(num.atleast_1d(freq))[0]

    def is_scalar(self):
        '''
        Check if this is a flat response.

        Only the plain base class is flat by default; derived classes report
        ``False`` unless they override this method.
        '''
        return type(self) is FrequencyResponse

    def get_scalar(self):
        '''
        Get factor if this is a flat response.

        :returns: ``1.0`` for the plain base class.
        :raises IsNotScalar: for derived classes which do not override this.
        '''
        if type(self) is FrequencyResponse:
            return 1.0

        raise IsNotScalar()  # default for derived classes

    def get_fmax(self):
        '''
        Get maximum frequency for which the response is defined.

        :returns:
            ``None`` if the response has no upper limit, otherwise the maximum
            frequency in [Hz] for which the response is valid is returned.
        :rtype:
            float or None
        '''
        return None

    def construction(self):
        '''
        Get the response's breakpoints as ``(frequency, change)`` pairs.

        The pairs are in the form processed by
        :py:func:`finalize_construction`. The base class has none.
        '''
        return []

    @property
    def summary(self):
        '''
        Short summary with key information about the response object.
        '''
        if type(self) is FrequencyResponse:
            return 'one'

        return 'unknown'
def str_gain(gain):
    '''
    Format a gain factor for response summaries.

    A gain equal to ``1.0`` gives ``'one'``. Complex gains are rendered
    with :py:func:`repr` and everything else with ``'%g'``, in both cases
    wrapped as ``'gain{...}'``.
    '''
    if gain == 1.0:
        return 'one'

    text = repr(gain) if isinstance(gain, complex) else '%g' % gain
    return 'gain{%s}' % text
class Gain(FrequencyResponse):
    '''
    A flat frequency response: ``T(f) = constant`` at every frequency.
    '''

    constant = Complex.T(default=1.0+0j)

    def evaluate(self, freqs):
        # Array shaped like ``freqs`` with every entry set to the constant.
        filled = util.num_full_like(freqs, self.constant, dtype=complex)
        return filled

    def is_scalar(self):
        # A gain is flat by construction.
        return True

    def get_scalar(self):
        factor = self.constant
        return factor

    @property
    def summary(self):
        label = str_gain(self.constant)
        return label
class Evalresp(FrequencyResponse):
    '''
    Calls evalresp and generates values of the instrument response transfer
    function.

    :param respfile: response file in evalresp format
    :param trace: trace for which the response is to be extracted from the
        file; when given, its ``nslc_id`` and the midpoint of its time span
        take precedence over ``nslc_id`` and ``time``
    :param target: ``'dis'`` for displacement or ``'vel'`` for velocity
    :param nslc_id: ``(network, station, location, channel)`` codes, used
        when no ``trace`` is given
    :param time: time instant at which the response is looked up, used when
        no ``trace`` is given
    :param stages: optional ``(start, stop)`` pair restricting the stages
        requested from evalresp
    '''

    respfile = String.T()
    nslc_id = Tuple.T(4, String.T())
    target = String.T(default='dis')
    instant = Float.T()
    stages = Tuple.T(2, Int.T(), optional=True)

    def __init__(
            self,
            respfile,
            trace=None,
            target='dis',
            nslc_id=None,
            time=None,
            stages=None,
            **kwargs):

        if trace is not None:
            # Codes and instant are taken from the trace instead.
            nslc_id = trace.nslc_id
            time = (trace.tmin + trace.tmax) / 2.

        FrequencyResponse.__init__(
            self,
            respfile=respfile,
            nslc_id=nslc_id,
            instant=time,
            target=target,
            stages=stages,
            **kwargs)

    def evaluate(self, freqs):
        # Map the optional ``stages`` tuple onto evalresp's start/stop stage
        # arguments. NOTE(review): (-1, 0) presumably selects all stages in
        # evalresp -- confirm against pyrocko.evalresp.
        if self.stages is None:
            start_stage, stop_stage = -1, 0
        else:
            start_stage = self.stages[0] + 1
            stop_stage = self.stages[1]

        net, sta, loc, cha = self.nslc_id
        result = evalresp.evalresp(
            sta_list=sta,
            cha_list=cha,
            net_code=net,
            locid=loc,
            instant=self.instant,
            freqs=freqs,
            units=self.target.upper(),
            file=self.respfile,
            start_stage=start_stage,
            stop_stage=stop_stage,
            rtype='CS')

        # Element 4 of the first result entry holds the transfer values.
        return result[0][4]

    @property
    def summary(self):
        return 'eresp'
class InverseEvalresp(FrequencyResponse):
    '''
    Calls evalresp and generates values of the inverse instrument response for
    deconvolution of instrument response.

    :param respfile: response file in evalresp format
    :param trace: trace for which the response is to be extracted from the
        file; its ``nslc_id`` and the midpoint of its time span are used
    :param target: ``'dis'`` for displacement or ``'vel'`` for velocity
    '''

    respfile = String.T()
    nslc_id = Tuple.T(4, String.T())
    target = String.T(default='dis')
    instant = Float.T()

    def __init__(self, respfile, trace, target='dis', **kwargs):
        codes = trace.nslc_id
        midpoint = (trace.tmin + trace.tmax)/2.

        FrequencyResponse.__init__(
            self,
            respfile=respfile,
            nslc_id=codes,
            instant=midpoint,
            target=target,
            **kwargs)

    def evaluate(self, freqs):
        net, sta, loc, cha = self.nslc_id
        result = evalresp.evalresp(
            sta_list=sta,
            cha_list=cha,
            net_code=net,
            locid=loc,
            instant=self.instant,
            freqs=freqs,
            units=self.target.upper(),
            file=self.respfile,
            rtype='CS')

        # Reciprocal of the forward transfer values (used for deconvolution).
        forward = result[0][4]
        return 1./forward

    @property
    def summary(self):
        return 'inv_eresp'
def aslist(x):
    '''
    Coerce ``x`` into a new list.

    ``None`` gives ``[]``, iterables are expanded with :py:func:`list` and
    any non-iterable object is wrapped as ``[x]``.
    '''
    if x is None:
        return []

    try:
        items = list(x)
    except TypeError:
        items = [x]

    return items
class PoleZeroResponse(FrequencyResponse):
    '''
    Evaluates frequency response from pole-zero representation.

    :param zeros: positions of zeros
    :type zeros: :py:class:`list` of :py:class:`complex`
    :param poles: positions of poles
    :type poles: :py:class:`list` of :py:class:`complex`
    :param constant: gain factor
    :type constant: complex

    ::

                           (j*2*pi*f - zeros[0]) * (j*2*pi*f - zeros[1]) * ...
        T(f) = constant * ----------------------------------------------------
                           (j*2*pi*f - poles[0]) * (j*2*pi*f - poles[1]) * ...


    The poles and zeros should be given as angular frequencies, not in Hz.
    '''

    zeros = List.T(Complex.T())
    poles = List.T(Complex.T())
    constant = Complex.T(default=1.0+0j)

    def __init__(
            self,
            zeros=None,
            poles=None,
            constant=1.0+0j,
            **kwargs):

        # aslist() turns None into [] and copies any other iterable.
        FrequencyResponse.__init__(
            self,
            zeros=aslist(zeros),
            poles=aslist(poles),
            constant=constant,
            **kwargs)

    def evaluate(self, freqs):
        if not hasattr(signal, 'freqs_zpk'):
            # Manual evaluation for scipy < 0.19.0, which lacks freqs_zpk.
            jomeg = 1.0j * 2.*num.pi*freqs
            resp = num.ones(freqs.size, dtype=complex)*self.constant
            for zero in self.zeros:
                resp *= jomeg - zero
            for pole in self.poles:
                resp /= jomeg - pole

            return resp

        return signal.freqs_zpk(
            self.zeros, self.poles, self.constant, freqs*2.*num.pi)[1]

    def is_scalar(self):
        # Flat exactly when there are neither zeros nor poles.
        return len(self.zeros) + len(self.poles) == 0

    def get_scalar(self):
        '''
        Get factor if this is a flat response.
        '''
        if not self.is_scalar():
            raise IsNotScalar()

        return self.constant

    def inverse(self):
        '''
        Get the inverse response as a new :py:class:`PoleZeroResponse`
        (poles and zeros swapped, reciprocal constant).
        '''
        return PoleZeroResponse(
            zeros=list(self.poles),
            poles=list(self.zeros),
            constant=1.0/self.constant)

    def to_analog(self):
        numerator, denominator = signal.zpk2tf(
            self.zeros, self.poles, self.constant)
        return AnalogFilterResponse(aslist(numerator), aslist(denominator))

    def to_digital(self, deltat, method='bilinear'):
        from scipy.signal import cont2discrete, zpk2tf

        zd, pd, kd, _ = cont2discrete(
            (self.zeros, self.poles, self.constant),
            deltat, method=method)

        return DigitalFilterResponse(*zpk2tf(zd, pd, kd), deltat)

    def to_digital_polezero(self, deltat, method='bilinear'):
        from scipy.signal import cont2discrete

        zd, pd, kd, _ = cont2discrete(
            (self.zeros, self.poles, self.constant),
            deltat, method=method)

        return DigitalPoleZeroResponse(zd, pd, kd, deltat)

    def construction(self):
        # Each zero contributes +1 and each pole -1 at frequency |x|/(2*pi).
        two_pi = 2.*math.pi
        breakpoints = [(abs(zero) / two_pi, 1) for zero in self.zeros]
        breakpoints.extend((abs(pole) / two_pi, -1) for pole in self.poles)
        return finalize_construction(breakpoints)

    @property
    def summary(self):
        if not self.is_scalar():
            return 'pz{%i,%i}' % (len(self.poles), len(self.zeros))

        return str_gain(self.get_scalar())
class DigitalPoleZeroResponse(FrequencyResponse):
    '''
    Evaluates frequency response from digital filter pole-zero representation.

    :param zeros: positions of zeros in the z-plane
    :type zeros: :py:class:`list` of :py:class:`complex`
    :param poles: positions of poles in the z-plane
    :type poles: :py:class:`list` of :py:class:`complex`
    :param constant: gain factor
    :type constant: complex
    :param deltat: sampling interval [s]
    :type deltat: float

    The transfer function is evaluated on the unit circle at
    ``z = exp(j*2*pi*f*deltat)`` for frequency ``f`` [Hz] (see
    :py:func:`scipy.signal.freqz_zpk`). Poles and zeros are therefore
    z-plane positions, not angular frequencies.
    '''

    zeros = List.T(Complex.T())
    poles = List.T(Complex.T())
    constant = Complex.T(default=1.0+0j)
    deltat = Float.T()

    def __init__(
            self,
            zeros=None,
            poles=None,
            constant=1.0+0j,
            deltat=None,
            **kwargs):

        if zeros is None:
            zeros = []
        if poles is None:
            poles = []
        if deltat is None:
            raise ValueError(
                'Sampling interval `deltat` must be given for '
                'DigitalPoleZeroResponse.')

        FrequencyResponse.__init__(
            self, zeros=aslist(zeros), poles=aslist(poles), constant=constant,
            deltat=deltat, **kwargs)

    def check_sampling_rate(self):
        '''
        :raises InvalidResponseError: if ``deltat`` is zero.
        '''
        if self.deltat == 0.0:
            raise InvalidResponseError(
                'Invalid digital response: sampling rate undefined.')

    def get_fmax(self):
        '''
        Get the Nyquist frequency ``0.5/deltat`` [Hz].
        '''
        self.check_sampling_rate()
        return 0.5 / self.deltat

    def evaluate(self, freqs):
        self.check_sampling_rate()
        # freqz_zpk with its default fs=2*pi expects normalized angular
        # frequencies [rad/sample].
        return signal.freqz_zpk(
            self.zeros, self.poles, self.constant,
            freqs*(2.*math.pi*self.deltat))[1]

    def is_scalar(self):
        return len(self.zeros) == 0 and len(self.poles) == 0

    def get_scalar(self):
        '''
        Get factor if this is a flat response.
        '''
        if self.is_scalar():
            return self.constant
        else:
            raise IsNotScalar()

    def to_digital(self, deltat=None):
        '''
        Convert to a :py:class:`DigitalFilterResponse` (``b``/``a`` form).

        :param deltat: sampling interval of the result. Defaults to this
            response's own ``deltat``. The pole/zero positions are tied to
            ``self.deltat``, so passing a different value relabels the
            filter rather than redesigning it.
        '''
        self.check_sampling_rate()
        from scipy.signal import zpk2tf

        if deltat is None:
            deltat = self.deltat

        b, a = zpk2tf(self.zeros, self.poles, self.constant)
        return DigitalFilterResponse(b, a, deltat)

    @property
    def summary(self):
        if self.is_scalar():
            return str_gain(self.get_scalar())

        return 'dpz{%i,%i,%s}' % (
            len(self.poles), len(self.zeros), str_fmax_failsafe(self))
class ButterworthResponse(FrequencyResponse):
    '''
    Butterworth frequency response.

    :param corner: corner frequency of the response [Hz]
    :param order: order of the response
    :param type: either ``high`` or ``low``
    '''

    corner = Float.T(default=1.0)
    order = Int.T(default=4)
    type = StringChoice.T(choices=['low', 'high'], default='low')

    def to_polezero(self):
        '''
        Analog design as a :py:class:`PoleZeroResponse`.
        '''
        zeros, poles, gain = signal.butter(
            self.order, self.corner*2.*math.pi,
            btype=self.type, analog=True, output='zpk')

        return PoleZeroResponse(
            zeros=aslist(zeros), poles=aslist(poles), constant=float(gain))

    def to_digital(self, deltat):
        '''
        Digital design as a :py:class:`DigitalFilterResponse`.
        '''
        # Digital designs take the corner relative to Nyquist (0.5/deltat).
        b, a = signal.butter(
            self.order, self.corner*2.*deltat,
            btype=self.type, analog=False)

        return DigitalFilterResponse(b, a, deltat)

    def to_analog(self):
        '''
        Analog design as an :py:class:`AnalogFilterResponse`.
        '''
        b, a = signal.butter(
            self.order, self.corner*2.*math.pi,
            btype=self.type, analog=True)

        return AnalogFilterResponse(b, a)

    def to_digital_polezero(self, deltat):
        '''
        Digital design as a :py:class:`DigitalPoleZeroResponse`.
        '''
        zeros, poles, gain = signal.butter(
            self.order, self.corner*2*deltat,
            btype=self.type, analog=False, output='zpk')

        return DigitalPoleZeroResponse(zeros, poles, gain, deltat)

    def evaluate(self, freqs):
        # Analog design; the corner and frequencies are converted to rad/s.
        b, a = signal.butter(
            self.order, self.corner*2.*math.pi,
            btype=self.type, analog=True)

        return signal.freqs(b, a, freqs*2.*math.pi)[1]

    @property
    def summary(self):
        return 'butter_%s{%i,%g}' % (self.type, self.order, self.corner)
class SampledResponse(FrequencyResponse):
    '''
    Interpolates frequency response given at a set of sampled frequencies.

    Real and imaginary parts are linearly interpolated separately (see
    :py:func:`numpy.interp`).

    :param frequencies,values: frequencies and values of the sampled response
        function.
    :param left,right: values to return when input is out of range. If set to
        ``None`` (the default) the endpoints are returned.
    '''

    frequencies = Array.T(shape=(None,), dtype=float, serialize_as='list')
    values = Array.T(shape=(None,), dtype=complex, serialize_as='list')
    left = Complex.T(optional=True)
    right = Complex.T(optional=True)

    def __init__(self, frequencies, values, left=None, right=None, **kwargs):
        # left/right were previously accepted but never stored, so they were
        # silently ignored (also by inverse()).
        FrequencyResponse.__init__(
            self,
            frequencies=asarray_1d(frequencies, float),
            values=asarray_1d(values, complex),
            left=left,
            right=right,
            **kwargs)

    def evaluate(self, freqs):
        def parts(x):
            # numpy.interp on real-valued data needs real fill values, so a
            # complex fill value is split for the two interpolations.
            if x is None:
                return None, None

            x = complex(x)
            return x.real, x.imag

        left_re, left_im = parts(self.left)
        right_re, right_im = parts(self.right)

        ereal = num.interp(
            freqs, self.frequencies, num.real(self.values),
            left=left_re, right=right_re)
        eimag = num.interp(
            freqs, self.frequencies, num.imag(self.values),
            left=left_im, right=right_im)

        return ereal + 1.0j*eimag

    def inverse(self):
        '''
        Get inverse as a new :py:class:`SampledResponse` object.

        Sampled values and the ``left``/``right`` fill values are replaced
        by their reciprocals.
        '''

        def inv_or_none(x):
            return None if x is None else 1./x

        return SampledResponse(
            self.frequencies, 1./self.values,
            left=inv_or_none(self.left),
            right=inv_or_none(self.right))

    @property
    def summary(self):
        return 'sampled'
class IntegrationResponse(FrequencyResponse):
    '''
    The integration response, optionally multiplied by a constant gain.

    :param n: exponent (integer)
    :param gain: gain factor (float)

    ::

                    gain
        T(f) = --------------
               (j*2*pi * f)^n

    At ``f = 0``, where the expression is singular, ``0`` is returned.
    '''

    n = Int.T(optional=True, default=1)
    gain = Float.T(optional=True, default=1.0)

    def __init__(self, n=1, gain=1.0, **kwargs):
        FrequencyResponse.__init__(self, n=n, gain=gain, **kwargs)

    def evaluate(self, freqs):
        resp = num.zeros(freqs.size, dtype=complex)
        # Entries at f == 0 are left at zero.
        mask = freqs != 0.0
        jomega = 1.0j * 2. * num.pi * freqs[mask]
        resp[mask] = self.gain / jomega**self.n
        return resp

    @property
    def summary(self):
        text = 'integration{%i}' % self.n
        if self.gain is not None and self.gain != 1.0:
            text += '*gain{%g}' % self.gain

        return text
class DifferentiationResponse(FrequencyResponse):
    '''
    The differentiation response, optionally multiplied by a constant gain.

    :param n: exponent (integer)
    :param gain: gain factor (float)

    ::

        T(f) = gain * (j*2*pi * f)^n
    '''

    n = Int.T(optional=True, default=1)
    gain = Float.T(optional=True, default=1.0)

    def __init__(self, n=1, gain=1.0, **kwargs):
        FrequencyResponse.__init__(self, n=n, gain=gain, **kwargs)

    def evaluate(self, freqs):
        jomega = 1.0j * 2. * num.pi * freqs
        return self.gain * jomega**self.n

    @property
    def summary(self):
        text = 'differentiation{%i}' % self.n
        if self.gain is not None and self.gain != 1.0:
            text += '*gain{%g}' % self.gain

        return text
class DigitalFilterResponse(FrequencyResponse):
    '''
    Frequency response of a digital filter.

    The filter is given by numerator coefficients ``b`` and denominator
    coefficients ``a`` at sampling interval ``deltat`` (see
    :py:func:`scipy.signal.freqz`). Frequencies above the Nyquist frequency
    ``0.5/deltat`` evaluate to NaN. If ``drop_phase`` is true,
    :py:meth:`evaluate` returns only the amplitude (absolute value).
    '''

    b = List.T(Float.T())
    a = List.T(Float.T())
    deltat = Float.T()
    drop_phase = Bool.T(default=False)

    def __init__(self, b, a, deltat, drop_phase=False, **kwargs):
        FrequencyResponse.__init__(
            self, b=aslist(b), a=aslist(a), deltat=float(deltat),
            drop_phase=drop_phase, **kwargs)

    def check_sampling_rate(self):
        '''
        :raises InvalidResponseError: if ``deltat`` is zero.
        '''
        if self.deltat == 0.0:
            raise InvalidResponseError(
                'Invalid digital response: sampling rate undefined.')

    def is_scalar(self):
        return len(self.a) == 1 and len(self.b) == 1

    def get_scalar(self):
        if self.is_scalar():
            return self.b[0] / self.a[0]
        else:
            raise IsNotScalar()

    def get_fmax(self):
        '''
        Get the Nyquist frequency, or ``None`` for a flat (scalar) filter.
        '''
        if not self.is_scalar():
            self.check_sampling_rate()
            return 0.5 / self.deltat
        else:
            return None

    def evaluate(self, freqs):
        if self.is_scalar():
            coeffs = util.num_full_like(
                freqs, self.get_scalar(), dtype=complex)
        else:
            self.check_sampling_rate()
            ok = freqs <= 0.5/self.deltat
            # Above Nyquist the response is undefined: fill with NaN.
            coeffs = num.full(freqs.size, num.nan, dtype=complex)
            if num.any(ok):
                # freqz with its default fs=2*pi expects [rad/sample].
                coeffs[ok] = signal.freqz(
                    self.b, self.a, freqs[ok]*2.*math.pi * self.deltat)[1]

        # drop_phase is honoured for flat filters too (previously skipped).
        if self.drop_phase:
            return num.abs(coeffs)

        return coeffs

    def filter(self, tr):
        '''
        Apply the filter to trace ``tr`` and return a filtered copy.
        '''
        self.check_sampling_rate()

        from pyrocko import trace
        trace.assert_same_sampling_rate(self, tr)
        tr_new = tr.copy(data=False)
        tr_new.set_ydata(signal.lfilter(self.b, self.a, tr.get_ydata()))
        return tr_new

    @property
    def summary(self):
        if self.is_scalar():
            return str_gain(self.get_scalar())

        elif len(self.a) == 1:
            return 'fir{%i,<=%sHz}' % (
                len(self.b), str_fmax_failsafe(self))

        else:
            return 'iir{%i,%i,<=%sHz}' % (
                len(self.b), len(self.a), str_fmax_failsafe(self))
class AnalogFilterResponse(FrequencyResponse):
    '''
    Frequency response of an analog filter.

    The filter is given by numerator coefficients ``b`` and denominator
    coefficients ``a`` (see :py:func:`scipy.signal.freqs`).
    '''

    b = List.T(Float.T())
    a = List.T(Float.T())

    def __init__(self, b, a, **kwargs):
        FrequencyResponse.__init__(
            self, b=aslist(b), a=aslist(a), **kwargs)

    def is_scalar(self):
        return len(self.a) == 1 and len(self.b) == 1

    def get_scalar(self):
        if self.is_scalar():
            return self.b[0] / self.a[0]
        else:
            raise IsNotScalar()

    def evaluate(self, freqs):
        return signal.freqs(self.b, self.a, freqs*2.*math.pi)[1]

    def to_digital(self, deltat, method='bilinear'):
        '''
        Discretize (default: bilinear transform) into a
        :py:class:`DigitalFilterResponse` with sampling interval ``deltat``.
        '''
        from scipy.signal import cont2discrete
        b, a, _ = cont2discrete((self.b, self.a), deltat, method=method)
        if b.ndim == 2:
            b = b[0]
        return DigitalFilterResponse(b.tolist(), a.tolist(), deltat)

    @property
    def summary(self):
        if self.is_scalar():
            return str_gain(self.get_scalar())

        # get_fmax() is not overridden here and returns None, so it cannot
        # be formatted with '%g' (the previous code raised TypeError).
        return 'analog{%i,%i}' % (len(self.b), len(self.a))
class MultiplyResponse(FrequencyResponse):
    '''
    Multiplication of several :py:class:`FrequencyResponse` objects.

    An empty product evaluates to the flat unit response.
    '''

    responses = List.T(FrequencyResponse.T())

    def __init__(self, responses=None, **kwargs):
        if responses is None:
            responses = []
        FrequencyResponse.__init__(self, responses=responses, **kwargs)

    def get_fmax(self):
        '''
        Get the smallest maximum frequency of the factors, or ``None`` if no
        factor has an upper limit.
        '''
        fmaxs = [resp.get_fmax() for resp in self.responses]
        fmaxs = [fmax for fmax in fmaxs if fmax is not None]
        if not fmaxs:
            return None
        else:
            return min(fmaxs)

    def evaluate(self, freqs):
        a = num.ones(freqs.size, dtype=complex)
        for resp in self.responses:
            a *= resp.evaluate(freqs)

        return a

    def is_scalar(self):
        return all(resp.is_scalar() for resp in self.responses)

    def get_scalar(self):
        '''
        Get factor if this is a flat response.
        '''
        if self.is_scalar():
            # numpy.prod needs a sequence: a bare generator is not iterated
            # element-wise but treated as one opaque object.
            return num.prod([resp.get_scalar() for resp in self.responses])
        else:
            raise IsNotScalar()

    def simplify(self):
        '''
        Replace ``responses`` in place by their simplified equivalent (see
        :py:func:`simplify_responses`).
        '''
        self.responses = simplify_responses(self.responses)

    def construction(self):
        breakpoints = []
        for resp in self.responses:
            breakpoints.extend(resp.construction())

        return finalize_construction(breakpoints)

    @property
    def summary(self):
        # Previously called as self.is_scalar(self), a TypeError.
        if self.is_scalar():
            return str_gain(self.get_scalar())
        else:
            xs = [x.summary for x in self.responses]
            return '(%s)' % ('*'.join(x for x in xs if x != 'one') or 'one')
class DelayResponse(FrequencyResponse):
    '''
    Frequency response of a time delay.

    ::

        T(f) = exp(-j*2*pi*f*delay)
    '''

    delay = Float.T(
        help='Time delay [s]')

    def evaluate(self, freqs):
        phase = -2.0j * self.delay * num.pi * freqs
        return num.exp(phase)

    @property
    def summary(self):
        return 'delay{%g}' % (self.delay,)
class InvalidResponseError(Exception):
    '''
    Raised when a response cannot be evaluated, e.g. by the digital
    responses' ``check_sampling_rate()`` when ``deltat`` is zero.
    '''
class InvalidResponse(FrequencyResponse):

    '''
    Frequency response returning NaN for all frequencies.

    When using :py:meth:`FrequencyResponse.evaluate` for the first time after
    instantiation, the user supplied warning :py:gattr:`message` is emitted.
    '''

    message = String.T(
        help='Warning message to be emitted when the response is used.')

    def __init__(self, message, **kwargs):
        # Extra kwargs (e.g. checkpoints) are forwarded like in the other
        # response classes.
        FrequencyResponse.__init__(self, message=message, **kwargs)
        self.have_warned = False

    def evaluate(self, freqs):
        # Warn only once per instance.
        if not self.have_warned:
            logger.warning('Invalid response: %s' % self.message)
            self.have_warned = True

        # ``num.complex`` was merely an alias of the builtin ``complex`` and
        # has been removed from NumPy (1.24); use the builtin directly.
        return util.num_full_like(freqs, None, dtype=complex)

    @property
    def summary(self):
        return 'invalid'
def simplify_responses(responses):
    '''
    Simplify a list of responses without changing their product.

    1. Nested :py:class:`MultiplyResponse` objects are flattened.
    2. All :py:class:`PoleZeroResponse` objects are merged into one; each
       pole cancels one exactly equal zero. The merged response is put
       first, or a :py:class:`Gain` if only a non-unit constant remains.
    3. All flat (scalar) responses are merged into a single
       :py:class:`Gain`, placed first.

    :param responses: iterable of :py:class:`FrequencyResponse` objects
    :returns: new list of responses
    '''

    def flatten(resps):
        for resp in resps:
            if isinstance(resp, MultiplyResponse):
                yield from flatten(resp.responses)
            else:
                yield resp

    def cancel(poles, zeros):
        # Remove one matching zero per pole; unmatched poles are kept.
        remaining_zeros = list(zeros)
        remaining_poles = []
        for pole in poles:
            if pole in remaining_zeros:
                remaining_zeros.remove(pole)
            else:
                remaining_poles.append(pole)

        return remaining_poles, remaining_zeros

    def merge_polezeros(resps):
        poles, zeros, others = [], [], []
        constant = 1.0
        for resp in resps:
            if not isinstance(resp, PoleZeroResponse):
                others.append(resp)
                continue

            poles.extend(resp.poles)
            zeros.extend(resp.zeros)
            constant *= resp.constant

        poles, zeros = cancel(poles, zeros)
        if poles or zeros:
            merged = PoleZeroResponse(
                poles=poles, zeros=zeros, constant=constant)
            others.insert(0, merged)
        elif constant != 1.0:
            others.insert(0, Gain(constant=constant))

        return others

    def merge_gains(resps):
        # Index 0 collects non-scalar, index 1 scalar responses.
        groups = ([], [])
        for resp in resps:
            groups[resp.is_scalar()].append(resp)

        non_scalars, scalars = groups
        result = []
        if scalars:
            factor = num.prod([resp.get_scalar() for resp in scalars])
            result.append(Gain(constant=factor))

        result.extend(non_scalars)
        return result

    return merge_gains(merge_polezeros(flatten(responses)))