Coverage for /usr/local/lib/python3.11/dist-packages/grond/targets/base.py: 85%
124 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
1import copy
3import numpy as num
5from pyrocko import gf
6from pyrocko.guts_array import Array
7from pyrocko.guts import Object, Float, String, Dict, List, Choice, load, dump
9from grond.analysers.base import AnalyserResult
10from grond.meta import has_get_plot_classes, GrondError
# Namespace prefix for the guts classes defined in this module
# (NOTE(review): presumably used by pyrocko.guts for serialisation tags —
# confirm against the guts documentation).
guts_prefix = "grond"
class TargetGroup(Object):
    '''Configuration shared by a group of misfit targets.

    Holds the normalisation family, path prefix, weight, interpolation
    method and Green's function store ID common to the group. Concrete
    groups are expected to override :meth:`get_targets`.
    '''

    normalisation_family = gf.StringID.T(
        optional=True,
        help='Group with common misfit normalisation')
    path = gf.StringID.T(
        optional=True,
        help='Targets.id will be prefixed with this path')
    weight = Float.T(
        default=1.0,
        help='Additional manual weight of the target group')
    interpolation = gf.InterpolationMethod.T(
        default='nearest_neighbor',
        help='Interpolation from pre-calculated GF store.')
    store_id = gf.StringID.T(
        optional=True,
        help='ID of the Green\'s function store for this TargetGroup.')

    def get_targets(self, ds, event, default_path='none'):
        '''Interface method for building the targets of this group.

        :param ds: dataset the targets are built from.
        :param event: event the targets are built for.
        :param default_path: path used when :attr:`path` is not set
            (NOTE(review): meaning inferred from the name — confirm in
            subclasses).
        :raises NotImplementedError: unless a truthy ``_targets`` attribute
            has been set on the instance.
        '''
        # TargetGroup itself never defines ``_targets``. A plain attribute
        # access would therefore raise AttributeError rather than the
        # intended NotImplementedError.
        if not getattr(self, '_targets', None):
            raise NotImplementedError()
        # NOTE(review): when ``_targets`` is set this still returns None,
        # as before — confirm whether it should return the targets instead.
class MisfitResult(Object):
    '''Container for misfit values.

    ``misfits`` is a float64 array with an arbitrary number of rows and
    exactly two columns (NOTE(review): the meaning of the two columns is
    not visible here — presumably misfit and normalisation; confirm).
    '''

    misfits = Array.T(dtype=num.float64, shape=(None, 2))
class MisfitConfig(Object):
    '''Base class for misfit configurations; declares no fields itself.

    An instance of it is the default ``misfit_config`` of a misfit target.
    '''
@has_get_plot_classes
class MisfitTarget(Object):
    '''Base class for targets whose misfit is evaluated against a model.

    :meth:`misfits_string_ids`, :meth:`post_process` and
    :meth:`init_bootstrap_residuals` raise :class:`NotImplementedError` here
    and are expected to be provided by concrete targets.
    '''

    manual_weight = Float.T(
        default=1.0,
        help='Relative weight of this target')
    analyser_results = Dict.T(
        gf.StringID.T(),
        AnalyserResult.T(),
        help='Dictionary of analyser results')
    normalisation_family = gf.StringID.T(
        optional=True,
        help='Normalisation family of this misfit target')
    path = gf.StringID.T(
        help='A path identifier used for plotting')
    misfit_config = MisfitConfig.T(
        default=MisfitConfig.D(),
        help='Misfit configuration')
    bootstrap_weights = Array.T(
        dtype=num.float64,
        serialize_as='base64',
        optional=True)
    bootstrap_residuals = Array.T(
        dtype=num.float64,
        serialize_as='base64',
        optional=True)

    # Capability flags; the base class supports neither kind of bootstrap.
    can_bootstrap_weights = False
    can_bootstrap_residuals = False

    plot_misfits_cumulative = True

    def __init__(self, **kwargs):
        Object.__init__(self, **kwargs)
        # Model parameters contributed by this target (empty in the base
        # class; presumably filled by subclasses).
        self.parameters = []

        self._ds = None
        self._result_mode = 'sparse'

        # Lazily populated caches, see get_combined_weight() and
        # target_parameters.
        self._combined_weight = None
        self._target_parameters = None
        self._target_ranges = None

    @classmethod
    def get_plot_classes(cls):
        '''Return the plot classes for this target type (none by default).'''
        return []

    def set_dataset(self, ds):
        '''Attach the dataset ``ds`` to this target.'''
        self._ds = ds

    def get_dataset(self):
        '''Return the dataset set via :meth:`set_dataset`, or None.'''
        return self._ds

    def string_id(self):
        '''Return :attr:`path` as a string.'''
        return str(self.path)

    def misfits_string_ids(self):
        '''Return string IDs for each misfit; must be overridden.

        :raises NotImplementedError: always, in the base class.
        '''
        raise NotImplementedError('%s does not implement misfits_string_ids'
                                  % self.__class__.__name__)

    @property
    def nmisfits(self):
        '''Number of misfit values produced by this target (1 by default).'''
        return 1

    def noise_weight_matrix(self):
        '''Return the noise weight matrix; a 1x1 identity by default.'''
        return num.array([[1]])

    @property
    def nparameters(self):
        '''Number of target parameters.

        Returns 0 until :attr:`target_parameters` has been accessed, because
        it only counts the cached copy.
        '''
        if self._target_parameters is None:
            return 0
        return len(self._target_parameters)

    @property
    def target_parameters(self):
        '''Deep copy of :attr:`parameters`, grouped under :meth:`string_id`.

        The copy is created and cached on first access.
        '''
        if self._target_parameters is None:
            self._target_parameters = copy.deepcopy(self.parameters)
            for p in self._target_parameters:
                p.set_groups([self.string_id()])
        return self._target_parameters

    @property
    def target_ranges(self):
        '''Parameter ranges of this target; empty in the base class.'''
        return {}

    def set_parameter_values(self, model):
        '''Store model values keyed by each parameter's ``name_nogroups``.

        :param model: indexable sequence of values, ordered like
            :attr:`parameters`.
        '''
        values = getattr(self, 'parameter_values', None)
        if values is None:
            # The base class never initialises this mapping. Create it on
            # first use instead of failing with AttributeError.
            values = {}
            self.parameter_values = values
        for i, p in enumerate(self.parameters):
            values[p.name_nogroups] = model[i]

    def set_result_mode(self, result_mode):
        '''Set the result mode (initially ``'sparse'``).'''
        self._result_mode = result_mode

    def post_process(self, engine, source, statics):
        '''Post-process modelling results; must be overridden.

        :raises NotImplementedError: always, in the base class.
        '''
        raise NotImplementedError()

    def get_combined_weight(self):
        '''Return the combined weight as a one-element float array.

        The weight is :attr:`manual_weight` multiplied by the ``weight`` of
        every entry in :attr:`analyser_results`. It is computed once and
        cached (later changes to the inputs are not picked up).
        '''
        if self._combined_weight is None:
            w = self.manual_weight
            for analyser in self.analyser_results.values():
                w *= analyser.weight
            self._combined_weight = num.array([w], dtype=float)

        return self._combined_weight

    def get_correlated_weights(self):
        '''Return correlated weights; None in the base class.'''
        pass

    def set_bootstrap_weights(self, weights):
        '''Store the flat bootstrap weights array.'''
        self.bootstrap_weights = weights

    def get_bootstrap_weights(self):
        '''Return bootstrap weights reshaped to ``(nbootstraps, nmisfits)``.

        :raises Exception: if no bootstrap weights have been set.
        '''
        if self.bootstrap_weights is None:
            raise Exception('Bootstrap weights have not been set!')
        nbootstraps = self.bootstrap_weights.size // self.nmisfits
        return self.bootstrap_weights.reshape(nbootstraps, self.nmisfits)

    def init_bootstrap_residuals(self, nbootstrap, rstate=None):
        '''Initialise bootstrap residuals; must be overridden.

        :raises NotImplementedError: always, in the base class.
        '''
        raise NotImplementedError()

    def set_bootstrap_residuals(self, residuals):
        '''Store the flat bootstrap residuals array.'''
        self.bootstrap_residuals = residuals

    def get_bootstrap_residuals(self):
        '''Return bootstrap residuals reshaped to ``(nbootstraps, nmisfits)``.

        :raises Exception: if no bootstrap residuals have been set.
        '''
        if self.bootstrap_residuals is None:
            raise Exception('Bootstrap residuals have not been set!')
        nbootstraps = self.bootstrap_residuals.size // self.nmisfits
        return self.bootstrap_residuals.reshape(nbootstraps, self.nmisfits)

    def prepare_modelling(self, engine, source, targets):
        ''' Prepare modelling target

        This function shall return a list of :class:`pyrocko.gf.Target`
        for forward modelling in the :class:`pyrocko.gf.LocalEngine`.
        The default returns ``[self]`` (NOTE(review): this assumes concrete
        subclasses are themselves usable as :class:`pyrocko.gf.Target`).
        '''
        return [self]

    def finalize_modelling(
            self, engine, source, modelling_targets, modelling_results):
        ''' Manipulate modelling before misfit calculation

        This function can be overloaded to interact with the modelling
        results. The default returns the first modelling result unchanged.
        '''
        return modelling_results[0]
class MisfitResultError(Object):
    '''Error record carrying a message.

    It may appear in a :class:`MisfitResultCollection` in place of a
    :class:`MisfitResult`.
    '''

    message = String.T()
class MisfitResultCollection(Object):
    '''Nested list of misfit results.

    ``results[i][j]`` is either a :class:`MisfitResult` or a
    :class:`MisfitResultError` (NOTE(review): what the outer and inner
    indices enumerate is not visible here — confirm with the producer).
    '''

    results = List.T(
        List.T(
            Choice.T([
                MisfitResult.T(),
                MisfitResultError.T(),
            ])
        )
    )
def dump_misfit_result_collection(misfit_result_collection, path):
    '''Write a misfit result collection to a file.

    Serialisation is delegated to :func:`pyrocko.guts.dump`.

    :param misfit_result_collection: object to write, normally a
        :class:`MisfitResultCollection`.
    :param path: destination filename.
    '''
    dump(
        misfit_result_collection,
        filename=path,
    )
def load_misfit_result_collection(path):
    '''Load a :class:`MisfitResultCollection` from a file.

    The file is read with :func:`pyrocko.guts.load`.

    :param path: filename to read.
    :returns: the loaded :class:`MisfitResultCollection`.
    :raises GrondError: if the file cannot be opened or read, or if it
        does not contain a :class:`MisfitResultCollection`.
    '''
    try:
        obj = load(filename=path)
    except OSError as e:
        # Chain the original error so the low-level cause stays visible
        # in the traceback.
        raise GrondError(
            'Failed to read ensemble misfit results from file "%s" (%s)' % (
                path, e)) from e

    if not isinstance(obj, MisfitResultCollection):
        raise GrondError(
            'File "%s" does not contain any misfit result collection.' % path)

    return obj
# Public API of this module (names exported by ``from ... import *``).
__all__ = [
    'TargetGroup',
    'MisfitTarget',
    'MisfitResult',
    'MisfitResultError',
    'dump_misfit_result_collection',
    'load_misfit_result_collection',
    'MisfitResultCollection',
]