Coverage for /usr/local/lib/python3.11/dist-packages/grond/targets/base.py: 85%

124 statements  

« prev     ^ index     » next       coverage.py v6.5.0, created at 2023-10-26 16:25 +0000

1import copy 

2 

3import numpy as num 

4 

5from pyrocko import gf 

6from pyrocko.guts_array import Array 

7from pyrocko.guts import Object, Float, String, Dict, List, Choice, load, dump 

8 

9from grond.analysers.base import AnalyserResult 

10from grond.meta import has_get_plot_classes, GrondError 

11 

12 

# Namespace prefix for the guts types declared in this module; presumably
# used by pyrocko.guts when tagging serialised objects — confirm.
guts_prefix = "grond"

14 

15 

class TargetGroup(Object):
    '''
    Base configuration for a group of misfit targets.

    Subclasses are expected to override :py:meth:`get_targets`.
    '''

    normalisation_family = gf.StringID.T(
        optional=True,
        help='Group with common misfit normalisation')
    path = gf.StringID.T(
        optional=True,
        help='Targets.id will be prefixed with this path')
    weight = Float.T(
        default=1.0,
        help='Additional manual weight of the target group')
    interpolation = gf.InterpolationMethod.T(
        default='nearest_neighbor',
        help='Interpolation from pre-calculated GF store.')
    store_id = gf.StringID.T(
        optional=True,
        help='ID of the Green\'s function store for this TargetGroup.')

    def get_targets(self, ds, event, default_path='none'):
        '''
        Build the misfit targets of this group.

        :param ds: dataset the targets are built from
        :param event: event the targets refer to
        :param default_path: path used when :py:attr:`path` is not set
        :raises NotImplementedError: when no truthy ``_targets`` attribute
            is present on the instance, i.e. a subclass did not override
            this method.

        NOTE(review): when a truthy ``_targets`` attribute *is* present, the
        base implementation falls through and returns ``None`` (behaviour
        kept for backward compatibility).
        '''
        # ``_targets`` is never declared on this class, so a plain attribute
        # access would raise AttributeError instead of the intended
        # NotImplementedError.
        if not getattr(self, '_targets', None):
            raise NotImplementedError(
                '%s does not implement get_targets'
                % self.__class__.__name__)

36 

37 

class MisfitResult(Object):
    '''Holds a float64 array of misfit values with two columns per row.'''

    misfits = Array.T(dtype=num.float64, shape=(None, 2))

42 

43 

class MisfitConfig(Object):
    '''Base class for misfit configuration objects; declares no fields.'''

46 

47 

@has_get_plot_classes
class MisfitTarget(Object):
    '''
    Base class for misfit targets.

    Holds the shared weighting, bootstrap and target-parameter bookkeeping.
    Subclasses are expected to implement :py:meth:`post_process`,
    :py:meth:`misfits_string_ids` and :py:meth:`init_bootstrap_residuals`;
    the base versions raise :py:exc:`NotImplementedError`.
    '''

    manual_weight = Float.T(
        default=1.0,
        help='Relative weight of this target')
    analyser_results = Dict.T(
        gf.StringID.T(),
        AnalyserResult.T(),
        help='Dictionary of analyser results')
    normalisation_family = gf.StringID.T(
        optional=True,
        help='Normalisation family of this misfit target')
    path = gf.StringID.T(
        help='A path identifier used for plotting')
    misfit_config = MisfitConfig.T(
        default=MisfitConfig.D(),
        help='Misfit configuration')
    bootstrap_weights = Array.T(
        dtype=num.float64,
        serialize_as='base64',
        optional=True)
    bootstrap_residuals = Array.T(
        dtype=num.float64,
        serialize_as='base64',
        optional=True)

    can_bootstrap_weights = False
    can_bootstrap_residuals = False

    plot_misfits_cumulative = True

    def __init__(self, **kwargs):
        Object.__init__(self, **kwargs)
        # Parameter objects of this target; empty in the base class.
        self.parameters = []

        self._ds = None
        self._result_mode = 'sparse'

        # Lazily filled caches (see get_combined_weight / target_parameters).
        self._combined_weight = None
        self._target_parameters = None
        self._target_ranges = None

    @classmethod
    def get_plot_classes(cls):
        '''Return the plot classes for this target type (none by default).'''
        return []

    def set_dataset(self, ds):
        '''Attach the dataset ``ds`` to this target.'''
        self._ds = ds

    def get_dataset(self):
        '''Return the dataset set via :py:meth:`set_dataset` (or None).'''
        return self._ds

    def string_id(self):
        '''Return the target's ``path`` as a string identifier.'''
        return str(self.path)

    def misfits_string_ids(self):
        '''Return string ids of the individual misfits (subclass hook).'''
        raise NotImplementedError('%s does not implement misfits_string_ids'
                                  % self.__class__.__name__)

    @property
    def nmisfits(self):
        '''Number of misfit values produced by this target (1 in base).'''
        return 1

    def noise_weight_matrix(self):
        '''Return the noise weight matrix; identity ``[[1]]`` in the base.'''
        return num.array([[1]])

    @property
    def nparameters(self):
        '''
        Number of target parameters.

        Uses the cached :py:attr:`target_parameters` when it exists. Otherwise
        it counts ``parameters`` directly; ``target_parameters`` is a deep
        copy of that list, so the lengths agree. Querying the count does not
        populate the cache.
        '''
        if self._target_parameters is None:
            return len(self.parameters)
        return len(self._target_parameters)

    @property
    def target_parameters(self):
        '''
        Deep copy of ``parameters`` with groups set to this target's id.

        Computed on first access and cached afterwards.
        '''
        if self._target_parameters is None:
            self._target_parameters = copy.deepcopy(self.parameters)
            for p in self._target_parameters:
                p.set_groups([self.string_id()])
        return self._target_parameters

    @property
    def target_ranges(self):
        '''Ranges of the target parameters (empty in the base class).'''
        return {}

    def set_parameter_values(self, model):
        '''
        Store target-parameter values taken from ``model``.

        ``model[i]`` is stored under ``self.parameters[i].name_nogroups`` in
        the ``parameter_values`` dict.
        '''
        if getattr(self, 'parameter_values', None) is None:
            # The base class never creates this dict; create it on demand
            # instead of failing with AttributeError.
            self.parameter_values = {}

        for i, p in enumerate(self.parameters):
            self.parameter_values[p.name_nogroups] = model[i]

    def set_result_mode(self, result_mode):
        '''Set the result mode (``'sparse'`` after construction).'''
        self._result_mode = result_mode

    def post_process(self, engine, source, statics):
        '''Compute the misfit from modelling results (subclass hook).'''
        raise NotImplementedError()

    def get_combined_weight(self):
        '''
        Return ``manual_weight`` times all analyser weights.

        The result is a one-element float array and is cached after the
        first call.
        '''
        if self._combined_weight is None:
            w = self.manual_weight
            for analyser in self.analyser_results.values():
                w *= analyser.weight
            self._combined_weight = num.array([w], dtype=float)

        return self._combined_weight

    def get_correlated_weights(self):
        '''Return correlated weights; the base class provides none (None).'''
        return None

    def set_bootstrap_weights(self, weights):
        '''Store the bootstrap weights array.'''
        self.bootstrap_weights = weights

    def get_bootstrap_weights(self):
        '''
        Return bootstrap weights reshaped to ``(nbootstraps, nmisfits)``.

        :raises GrondError: if no bootstrap weights have been set.
        '''
        if self.bootstrap_weights is None:
            raise GrondError('Bootstrap weights have not been set!')
        nbootstraps = self.bootstrap_weights.size // self.nmisfits
        return self.bootstrap_weights.reshape(nbootstraps, self.nmisfits)

    def init_bootstrap_residuals(self, nbootstrap, rstate=None):
        '''Initialise bootstrap residuals (subclass hook).'''
        raise NotImplementedError()

    def set_bootstrap_residuals(self, residuals):
        '''Store the bootstrap residuals array.'''
        self.bootstrap_residuals = residuals

    def get_bootstrap_residuals(self):
        '''
        Return bootstrap residuals reshaped to ``(nbootstraps, nmisfits)``.

        :raises GrondError: if no bootstrap residuals have been set.
        '''
        if self.bootstrap_residuals is None:
            raise GrondError('Bootstrap residuals have not been set!')
        nbootstraps = self.bootstrap_residuals.size // self.nmisfits
        return self.bootstrap_residuals.reshape(nbootstraps, self.nmisfits)

    def prepare_modelling(self, engine, source, targets):
        ''' Prepare modelling target

        This function shall return a list of :class:`pyrocko.gf.Target`
        for forward modelling in the :class:`pyrocko.gf.LocalEngine`.
        '''
        return [self]

    def finalize_modelling(
            self, engine, source, modelling_targets, modelling_results):
        ''' Manipulate modelling before misfit calculation

        This function can be overloaded to interact with the modelling
        results. The base implementation returns the first result.
        '''
        return modelling_results[0]

193 

194 

class MisfitResultError(Object):
    '''Error entry in a misfit result collection, carrying a message.'''

    message = String.T()

197 

198 

class MisfitResultCollection(Object):
    '''
    Nested lists of misfit outcomes.

    Each innermost entry is either a MisfitResult or a MisfitResultError.
    '''

    results = List.T(
        List.T(
            Choice.T([
                MisfitResult.T(),
                MisfitResultError.T(),
            ])))

202 

203 

def dump_misfit_result_collection(misfit_result_collection, path):
    '''
    Serialise a misfit result collection to a file.

    :param misfit_result_collection: object to write
    :param path: destination file name
    '''
    target_file = path
    dump(misfit_result_collection, filename=target_file)

206 

207 

def load_misfit_result_collection(path):
    '''
    Read a :py:class:`MisfitResultCollection` from a file.

    :param path: file name to read
    :returns: the loaded MisfitResultCollection
    :raises GrondError: if reading the file raises OSError, or if the file
        does not contain a MisfitResultCollection
    '''
    try:
        obj = load(filename=path)
    except OSError as e:
        # Chain the original error so the underlying cause stays visible.
        raise GrondError(
            'Failed to read ensemble misfit results from file "%s" (%s)' % (
                path, e)) from e

    if not isinstance(obj, MisfitResultCollection):
        raise GrondError(
            'File "%s" does not contain any misfit result collection.' % path)

    return obj

222 

223 

# Public names exported by this module.
__all__ = [
    'TargetGroup',
    'MisfitTarget',
    'MisfitResult',
    'MisfitResultError',
    'dump_misfit_result_collection',
    'load_misfit_result_collection',
    'MisfitResultCollection',
]