Coverage for /usr/local/lib/python3.11/dist-packages/grond/analysers/target_balancing/analyser.py: 90%

80 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2024-11-27 15:15 +0000

1# https://pyrocko.org/grond - GPLv3 

2# 

3# The Grond Developers, 21st Century 

4import copy 

5import time 

6import logging 

7import numpy as num 

8from pyrocko.guts import Int, Float, Bool 

9from pyrocko import gf 

10 

11from grond.meta import Forbidden, GrondError 

12 

13from ..base import Analyser, AnalyserConfig, AnalyserResult 

14 

# Module logger. NOTE(review): the name differs from this module's path
# (grond/analysers/target_balancing/analyser.py); kept unchanged because
# logging configuration elsewhere may refer to it — confirm before renaming.
logger = logging.getLogger('grond.analysers.target_balancer')


# Module-level guts prefix; presumably read by pyrocko.guts to namespace the
# tags of the guts classes defined below — keep it defined before them.
guts_prefix = 'grond'

19 

20 

class TargetBalancingAnalyser(Analyser):
    """ Estimate target weights that balance the signal amplitudes.

    Signal amplitudes depend on the source-receiver distance, on the
    phase type and the taper used. Large signals have in general
    a higher contribution to the misfit than smaller signals,
    without carrying more information. With this analyser, weights are
    estimated that shall balance the phase contributions.

    The weight estimation is based on synthetic waveforms that stem from
    a given number of random forward models. The inverse of the mean
    synthetic signal amplitudes gives the balancing weight. This is
    described as adaptive station weighting in Heimann (2011).
    """

    def __init__(self, niter, use_reference_magnitude, cutoff):
        """
        :param niter: number of random forward models to evaluate; ``0``
            disables the analysis.
        :param use_reference_magnitude: if true, the magnitude of every
            random source is fixed to the magnitude of the problem's base
            source.
        :param cutoff: if not ``None``, targets whose mean m/p ratio exceeds
            this value get a weight of ``0.0``.
        """
        Analyser.__init__(self)
        self.niter = niter
        self.use_reference_magnitude = use_reference_magnitude
        self.cutoff = cutoff
        # Time of the last progress message. Initialised here so that
        # log_progress() is also safe to call before analyse().
        self._tlog_last = 0

    def log_progress(self, problem, iiter, niter):
        """Log progress at most every 10 s, and always at first/last step."""
        t = time.time()
        if self._tlog_last < t - 10. \
                or iiter == 0 \
                or iiter == niter - 1:

            logger.info(
                'Target balancing for "%s" at %i/%i.',
                problem.name, iiter, niter)

            self._tlog_last = t

    def analyse(self, problem, ds):
        """Estimate balancing weights for the waveform targets of *problem*.

        Copies of the waveform targets (unit weight, ``flip_norm`` set) are
        evaluated for ``self.niter`` random models drawn with a fixed seed.
        A target's weight is the inverse of its mean ``p`` misfit component,
        rescaled so that the finite weights within each target family have
        a mean of one. If ``cutoff`` is set, targets with mean ``m/p`` above
        it get weight ``0.0``. Each result is stored as
        ``target.analyser_results['target_balancing']`` on the waveform
        targets of *problem*.

        :param problem: problem whose waveform targets are analysed.
        :param ds: not used by this analyser.
        :raises GrondError: if ``use_reference_magnitude`` is set but the
            magnitude of the base source cannot be determined.
        """
        if self.niter == 0 or not problem.has_waveforms:
            return

        wtargets = []
        for target in problem.waveform_targets:
            wtarget = copy.copy(target)
            wtarget.flip_norm = True
            wtarget.weight = 1.0
            wtargets.append(wtarget)

        wproblem = problem.copy()
        wproblem.targets = wtargets

        xbounds = wproblem.get_parameter_bounds()

        # The reference magnitude is the same for every draw, so look it up
        # once instead of once per random model.
        fixed_magnitude = None
        if self.use_reference_magnitude:
            try:
                fixed_magnitude = wproblem.base_source.get_magnitude()
            except gf.DerivedMagnitudeError as exc:
                raise GrondError(
                    'Cannot use use_reference_magnitude for this type '
                    'of source model.') from exc

        misfits = num.zeros((self.niter, wproblem.ntargets, 2))
        # Fixed seed: repeated runs yield the same weights.
        rstate = num.random.RandomState(123)

        isbad_mask = None

        self._tlog_last = 0
        for iiter in range(self.niter):
            self.log_progress(problem, iiter, self.niter)
            x = self._draw_model(wproblem, xbounds, rstate, fixed_magnitude)

            # Targets whose p misfit was NaN in the previous iteration are
            # excluded from evaluation via the mask.
            if isbad_mask is not None and num.any(isbad_mask):
                isok_mask = num.logical_not(isbad_mask)
            else:
                isok_mask = None

            misfits[iiter, :, :] = wproblem.misfits(x, mask=isok_mask)

            isbad_mask = num.isnan(misfits[iiter, :, 1])

        mean_ms = num.mean(misfits[:, :, 0], axis=0)
        mean_ps = num.mean(misfits[:, :, 1], axis=0)
        ratios = mean_ms / mean_ps

        weights = 1. / mean_ps
        families, nfamilies = wproblem.get_family_mask()
        self._normalize_family_weights(weights, families, nfamilies)

        if self.cutoff is not None:
            weights[ratios > self.cutoff] = 0.0

        for weight, target in zip(weights, problem.waveform_targets):
            target.analyser_results['target_balancing'] = \
                TargetBalancingAnalyserResult(weight=float(weight))

        for target, ratio, weight in zip(
                problem.waveform_targets, ratios, weights):
            logger.info(
                'Balancing analysis for target "%s":\n'
                ' m/p: %g\n'
                ' weight: %g\n',
                target.string_id(), ratio, weight)

    @staticmethod
    def _draw_model(wproblem, xbounds, rstate, fixed_magnitude):
        """Draw random models until one passes ``wproblem.preconstrain``."""
        # NOTE(review): loops forever if every draw is Forbidden.
        while True:
            x = wproblem.random_uniform(
                xbounds, rstate, fixed_magnitude=fixed_magnitude)
            try:
                return wproblem.preconstrain(x)
            except Forbidden:
                pass  # rejected by the preconstraints; draw again

    @staticmethod
    def _normalize_family_weights(weights, families, nfamilies):
        """Rescale *weights* in place so each family's finite weights
        have a mean of one.

        Non-finite weights (NaN where the mean p is NaN, inf where it is
        zero) are left untouched and excluded from the family mean, so they
        cannot corrupt the weights of the other targets in that family.
        Families without any finite weight are left unchanged.
        """
        for ifamily in range(nfamilies):
            in_family = families == ifamily
            family_weights = weights[in_family]
            finite = num.isfinite(family_weights)
            if num.any(finite):
                weights[in_family] = \
                    family_weights / num.mean(family_weights[finite])

139 

140 

class TargetBalancingAnalyserResult(AnalyserResult):
    """Target balancing result attached to a single waveform target."""

    # Balancing weight estimated for the target (0.0 when removed by the
    # cutoff; may be non-finite if no usable estimate was obtained).
    weight = Float.T()

143 

144 

class TargetBalancingAnalyserConfig(AnalyserConfig):
    """Configuration parameters of the target balancing."""
    # NOTE(review): guts property declaration order presumably defines the
    # serialization order — keep the properties below in this order.

    niterations = Int.T(
        default=1000,
        help='Number of random forward models for mean phase amplitude '
             'estimation')

    use_reference_magnitude = Bool.T(
        default=False,
        help='Fix magnitude of random sources to the magnitude of the '
             'reference event.')

    # Unset (None) disables the m/p cutoff.
    cutoff = Float.T(
        optional=True,
        help='Remove targets where ratio m/p > cutoff, where m is the misfit '
             'between synthetics and observations and p is the misfit between '
             'synthetics and zero-traces. Magnitude should be fixed to use '
             'this.')

    def get_analyser(self):
        """Create a TargetBalancingAnalyser from this configuration."""
        return TargetBalancingAnalyser(
            niter=self.niterations,
            use_reference_magnitude=self.use_reference_magnitude,
            cutoff=self.cutoff)

169 

170 

# Public names exported by ``from ... import *``.
__all__ = [
    'TargetBalancingAnalyser',
    'TargetBalancingAnalyserConfig',
]