Coverage for /usr/local/lib/python3.11/dist-packages/grond/targets/waveform_phase_ratio/plot.py: 21%
113 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
1import numpy as num
2from matplotlib import pyplot as plt
3from matplotlib import cm
5from pyrocko.guts import Tuple, Float, String, Int
6from pyrocko.plot import mpl_init, mpl_margins
7from pyrocko import gf
9from grond import core, meta
10from grond.plot.config import PlotConfig
11from grond.plot.collection import PlotItem
13from .target import PhaseRatioTarget
guts_prefix = 'grond'


def S(text):
    '''Normalise whitespace in *text*.

    Every run of whitespace (spaces, tabs, newlines) becomes a single space.
    Leading and trailing whitespace is dropped, which makes indented
    triple-quoted literals usable as one-line help or description text.
    '''
    words = text.split()
    return ' '.join(words)
class FitsPhaseRatioPlot(PlotConfig):
    '''Plot showing the phase ratio fits for the best model.'''

    name = 'fits_phase_ratio'

    size_cm = Tuple.T(
        2, Float.T(),
        default=(15., 7.5),
        help='Width and length of the figure in [cm]')

    misfit_cutoff = Float.T(
        optional=True,
        help='Plot fits for models up to this misfit value.')

    color_parameter = String.T(
        default='misfit',
        help='Choice of value to color, options: "misfit" (default), '
             '"dist" or any source parameter name.')

    istride_ensemble = Int.T(
        default=10,
        help='Stride value N to choose every Nth model from the solution '
             'ensemble.')

    font_size_title = Float.T(
        default=10,
        help='Font size of title [pt]')

    def make(self, environ):
        '''Register the phase ratio fit figures with the plot collection.

        :param environ: plotting environment. It supplies the plot
            collection manager, the dataset and the harvested optimisation
            history used to draw the figures.
        '''
        # Named so that it does not shadow the module-level ``matplotlib.cm``
        # import used by draw_figure().
        collection_manager = environ.get_plot_collection_manager()
        mpl_init(fontsize=self.font_size)

        environ.setup_modelling()
        ds = environ.get_dataset()
        history = environ.get_history(subset='harvest')

        # Sentence describing the marker colouring, inserted into the
        # figure-group description below.
        scolor = {
            'misfit': S('''
                The synthetic markers are colored according to their
                respective global (non-bootstrapped) misfit value. Red
                indicates better fit, blue worse.'''),

            'dist': S('''
                The synthetic markers are colored according to their
                Mahalanobis distance from the mean solution.''')

        }.get(self.color_parameter, S('''
            The synthetic markers are colored according to source
            parameter "%s".''' % self.color_parameter))

        collection_manager.create_group_mpl(
            self, self.draw_figures(ds, history),
            title=u'Fits Phase Ratios',
            section='fits',
            feather_icon='activity',
            description=u'''
Observed (black markers) and synthetic waveform amplitude phase ratio estimates
(spectral average ratio; colored markers) at different stations for every Nth
model in the bootstrap solution ensemble (N=%i).

%s

The frequency range used to estimate spectral averages is not shown (see config
file). Optimal solutions show good agreement between black and colored markers.
''' % (self.istride_ensemble, scolor))

    def draw_figures(self, ds, history):
        '''Yield ``(item, fig)`` pairs, one figure per phase ratio target path.

        :param ds: dataset attached to every target of the problem.
        :param history: optimisation history; ``history.problem`` provides
            the targets.
        '''
        problem = history.problem

        # The dataset is attached to all targets, not only the phase ratio
        # ones.
        for target in problem.targets:
            target.set_dataset(ds)

        tpaths = sorted(set(
            t.path for t in problem.targets
            if isinstance(t, PhaseRatioTarget)))

        for tpath in tpaths:
            yield from self.draw_figure(ds, history, tpath)

    def _color_values(self, problem, models):
        '''Return one scalar per model that selects its marker colour.

        :param problem: the optimisation problem (parameter lookup).
        :param models: 2-D array with one model per row.
        :raises ValueError: if ``color_parameter`` is neither ``"misfit"``,
            ``"dist"`` nor one of ``problem.parameter_names``.
        '''
        color_parameter = self.color_parameter

        if color_parameter == 'dist':
            # Mahalanobis distance of each model from the ensemble mean,
            # converted by meta.ordersort (presumably into ranks).
            mx = num.mean(models, axis=0)
            cov = num.cov(models.T)
            mdists = core.mahalanobis_distance(models, mx, cov)
            return meta.ordersort(mdists)

        elif color_parameter == 'misfit':
            # Models arrive in the reverse of the history's sort order, so
            # the position index is used as the colour value. Together with
            # the 'coolwarm' map this colours later models red.
            return num.arange(models.shape[0])

        elif color_parameter in problem.parameter_names:
            ind = problem.name_to_index(color_parameter)
            return problem.extract(models, ind)

        raise ValueError(
            'Invalid color_parameter %r: expected "misfit", "dist" or one '
            'of the source parameter names.' % color_parameter)

    def draw_figure(self, ds, history, tpath):
        '''Draw observed vs. synthetic phase ratios for one target path.

        :param ds: dataset (unused here; kept for interface compatibility).
        :param history: optimisation history providing problem and models.
        :param tpath: target path selecting the phase ratio targets.
        :returns: generator yielding at most one ``(item, fig)`` pair.
            It yields nothing when no model survives the misfit cutoff and
            stride, or when every forward evaluation failed.
        '''
        from matplotlib import colors

        problem = history.problem
        misfit_cutoff = self.misfit_cutoff
        fontsize = self.font_size
        targets = [
            t for t in problem.targets
            if isinstance(t, PhaseRatioTarget) and t.path == tpath]

        gms = history.get_sorted_misfits(chain=0)[::-1]
        models = history.get_sorted_models(chain=0)[::-1]

        if misfit_cutoff is not None:
            models = models[gms < misfit_cutoff]

        models = models[::self.istride_ensemble]

        nmodels = models.shape[0]
        if nmodels == 0:
            # num.min/num.max below would fail on an empty colour array.
            return

        icolor = self._color_values(problem, models)

        cmap = cm.ScalarMappable(
            norm=colors.Normalize(vmin=num.min(icolor), vmax=num.max(icolor)),
            cmap=plt.get_cmap('coolwarm'))

        imodel_to_color = [
            cmap.to_rgba(icolor[imodel]) for imodel in range(nmodels)]

        # One record per (model, target) evaluation that succeeded:
        # (target label, model index, observed ratio, synthetic ratio).
        data = []
        for imodel in range(nmodels):
            results = problem.evaluate(models[imodel, :], targets=targets)

            for target, result in zip(targets, results):
                if isinstance(result, gf.SeismosizerError):
                    continue

                r_obs = result.a_obs / (result.a_obs + result.b_obs)
                r_syn = result.a_syn / (result.a_syn + result.b_syn)

                data.append(('.'.join(target.codes), imodel, r_obs, r_syn))

        if not data:
            # Every evaluation failed; unpacking r_obs/r_syn below would
            # fail on an empty array.
            return

        item = PlotItem(
            name='fig_%s' % tpath)

        item.attributes['targets'] = [
            t.string_id() for t in targets]

        fig = plt.figure(figsize=self.size_inch)

        labelpos = mpl_margins(
            fig, nw=1, nh=1, left=7., right=1., bottom=10., top=3,
            units=fontsize)

        axes = fig.add_subplot(1, 1, 1)
        labelpos(axes, 2.5, 2.0)

        labels = sorted(set(x[0] for x in data))

        ntargets = len(labels)
        string_id_to_itarget = dict((x, i) for (i, x) in enumerate(labels))

        itargets = num.array([string_id_to_itarget[x[0]] for x in data])

        imodels = num.array([x[1] for x in data], dtype=int)
        r_obs, r_syn = num.array([x[2:] for x in data]).T

        r_obs_median = num.zeros(ntargets)
        for itarget in range(ntargets):
            r_obs_median[itarget] = num.median(r_obs[itargets == itarget])

        # x position of each target, derived from its median observed ratio.
        iorder = meta.ordersort(r_obs_median)

        for imodel in range(nmodels):
            mask = imodels == imodel
            axes.plot(
                iorder[itargets[mask]], r_obs[mask], '_',
                ms=20.,
                zorder=-10,
                alpha=0.5,
                color='black')
            axes.plot(
                iorder[itargets[mask]], r_syn[mask], '_',
                ms=10.,
                alpha=0.5,
                color=imodel_to_color[imodel])

        axes.set_yscale('log')
        axes.set_ylabel('Ratio')

        axes.set_xticks(
            iorder[num.arange(ntargets)])
        axes.set_xticklabels(labels, rotation='vertical')

        fig.suptitle(tpath, fontsize=self.font_size_title)

        yield item, fig
def get_plot_classes():
    '''List the plot configuration classes this module contributes.

    A new list is built on each call, so callers may modify it freely.
    '''
    plot_classes = [FitsPhaseRatioPlot]
    return plot_classes