Coverage for /usr/local/lib/python3.11/dist-packages/grond/optimisers/highscore/plot.py: 18%
221 statements
« prev ^ index » next coverage.py v6.5.0, created at 2024-11-27 15:15 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2024-11-27 15:15 +0000
1# https://pyrocko.org/grond - GPLv3
2#
3# The Grond Developers, 21st Century
4import logging
5import numpy as num
7from matplotlib import pyplot as plt
8from matplotlib.ticker import FuncFormatter
10from pyrocko.plot import mpl_init, mpl_margins, mpl_color
11from pyrocko.guts import Tuple, Float
12from pyrocko import trace
14from grond.plot.config import PlotConfig
15from grond.plot.collection import PlotItem
# Module logger; the name is kept verbatim because log filters may rely on it.
logger = logging.getLogger(name='grond.optimiser.highscore.plot')

# Module-level type-tag prefix for the guts config classes defined here
# (presumably consumed by pyrocko.guts when (de)serialising - confirm).
guts_prefix = 'grond'
22def _pcolormesh_same_dim(ax, x, y, v, **kwargs):
23 # x, y, v must have the same dimension
24 try:
25 return ax.pcolormesh(x, y, v, shading='nearest', **kwargs)
26 except TypeError:
27 # matplotlib versions < 3.3
28 return ax.pcolormesh(x, y, v[:-1, :-1], **kwargs)
class HighScoreOptimiserPlot(object):
    '''
    Frame-by-frame 2-D view of a high-score optimiser run.

    Each frame shows, projected onto the parameter pair
    ``xpar_name`` / ``ypar_name``:

    * the models tested so far (black dots),
    * the models of every chain (one colour per chain),
    * a cloud of candidates drawn from the active sampler phase (green).

    Frames can be recorded to ``movie_filename`` through matplotlib's
    ``FFMpegWriter``. :py:meth:`render` drives the whole sequence.
    '''

    def __init__(
            self, optimiser, problem, history, xpar_name, ypar_name,
            movie_filename):
        '''
        :param optimiser: optimiser providing ``chains(problem, history)``,
            ``nbootstrap`` and ``get_sampler_phase(iiter)``.
        :param problem: problem whose parameters are plotted.
        :param history: run history providing ``models`` and ``nmodels``.
        :param xpar_name: name of the parameter plotted on the x axis.
        :param ypar_name: name of the parameter plotted on the y axis.
        :param movie_filename: output movie path, or a false value for no
            movie.
        '''
        self.optimiser = optimiser
        self.problem = problem
        self.chains = optimiser.chains(problem, history)
        self.history = history
        self.xpar_name = xpar_name
        self.ypar_name = ypar_name
        self.fontsize = 10.
        self.movie_filename = movie_filename
        self.show = False
        self.iiter = 0
        self.iiter_last_draw = 0
        # Artists that are removed and redrawn on every frame.
        self._volatile = []
        # Indices of 100-model history blocks that are full and drawn for good.
        self._blocks_complete = set()

    def start(self):
        '''
        Set up the figure, axes, colours, colormap and optional movie writer.

        Must be called before :py:meth:`draw_frame`.
        '''
        nfx = 1
        nfy = 1

        problem = self.problem

        ixpar = problem.name_to_index(self.xpar_name)
        iypar = problem.name_to_index(self.ypar_name)

        mpl_init(fontsize=self.fontsize)
        fig = plt.figure(figsize=(9.6, 5.4))
        labelpos = mpl_margins(fig, nw=nfx, nh=nfy, w=7., h=5., wspace=7.,
                               hspace=2., units=self.fontsize)

        xpar = problem.parameters[ixpar]
        ypar = problem.parameters[iypar]

        # Equal aspect ratio only when both axes share the same unit.
        if xpar.unit == ypar.unit:
            axes = fig.add_subplot(nfy, nfx, 1, aspect=1.0)
        else:
            axes = fig.add_subplot(nfy, nfx, 1)

        labelpos(axes, 2.5, 2.0)

        axes.set_xlabel(xpar.get_label())
        axes.set_ylabel(ypar.get_label())

        axes.get_xaxis().set_major_locator(plt.MaxNLocator(4))
        axes.get_yaxis().set_major_locator(plt.MaxNLocator(4))

        # Faint cross-hair at the problem's reference model.
        xref = problem.get_reference_model()
        axes.axvline(xpar.scaled(xref[ixpar]), color='black', alpha=0.3)
        axes.axhline(ypar.scaled(xref[iypar]), color='black', alpha=0.3)

        self.fig = fig
        self.xpar = xpar
        self.ypar = ypar
        self.axes = axes
        self.ixpar = ixpar
        self.iypar = iypar

        from matplotlib import colors
        # nbootstrap + 1 random colours; draw_frame indexes them by chain
        # number. Entry 0 is forced to black.
        n = self.optimiser.nbootstrap + 1
        hsv = num.vstack((
            num.random.uniform(0., 1., n),
            num.random.uniform(0.5, 0.9, n),
            num.repeat(0.7, n))).T

        self.bcolors = colors.hsv_to_rgb(hsv[num.newaxis, :, :])[0, :, :]
        self.bcolors[0, :] = [0., 0., 0.]

        bounds = self.problem.get_combined_bounds()

        from grond import plot
        self.xlim = plot.fixlim(*xpar.scaled(bounds[ixpar]))
        self.ylim = plot.fixlim(*ypar.scaled(bounds[iypar]))

        self.set_limits()

        from matplotlib.colors import LinearSegmentedColormap

        self.cmap = LinearSegmentedColormap.from_list('probability', [
            (1.0, 1.0, 1.0),
            (0.5, 0.9, 0.6)])

        self.writer = None
        if self.movie_filename:
            from matplotlib.animation import FFMpegWriter

            metadata = dict(title=problem.name, artist='Grond')

            self.writer = FFMpegWriter(
                fps=30,
                metadata=metadata,
                codec='libx264',
                bitrate=200000,
                extra_args=[
                    '-pix_fmt', 'yuv420p',
                    '-profile:v', 'baseline',
                    '-level', '3',
                    '-an'])

            self.writer.setup(self.fig, self.movie_filename, dpi=200)

        if self.show:
            plt.ion()
            plt.show()

    def set_limits(self):
        '''Fix the axes limits to the scaled combined parameter bounds.'''
        self.axes.autoscale(False)
        self.axes.set_xlim(*self.xlim)
        self.axes.set_ylim(*self.ylim)

    def draw_frame(self):
        '''
        Draw the state after iteration ``self.iiter`` and record it.

        All volatile artists from the previous frame are removed first. The
        frame label is added *before* the movie frame is grabbed, so it is
        part of the recorded movie.
        '''
        self.chains.goto(self.iiter+1)
        msize = 15.

        for artist in self._volatile:
            artist.remove()

        self._volatile[:] = []

        self._draw_history(msize)
        self._draw_chains(msize)
        phase = self._draw_phase_samples(msize)
        self._draw_label(phase)

        if self.writer:
            self.writer.grab_frame()

        if self.show:
            plt.draw()

        self.iiter_last_draw = self.iiter + 1

    def _draw_history(self, msize):
        '''Scatter all models tested so far, in blocks of 100 models.'''
        nblock = 100
        nblocks = self.iiter // nblock + 1

        for iblock in range(nblocks):
            if iblock in self._blocks_complete:
                continue

            models_add = self.history.models[
                iblock*nblock:min((iblock+1)*nblock, self.iiter+1)]

            fx = self.problem.extract(models_add, self.ixpar)
            fy = self.problem.extract(models_add, self.iypar)
            collection = self.axes.scatter(
                self.xpar.scaled(fx),
                self.ypar.scaled(fy),
                color='black',
                s=msize * 0.15, alpha=0.2, edgecolors='none')

            if models_add.shape[0] != nblock:
                # Partial block: it grows, so redraw it on the next frame.
                self._volatile.append(collection)
            else:
                # Full block: it never changes again, keep it on the axes.
                self._blocks_complete.add(iblock)

    def _draw_chains(self, msize):
        '''Scatter the models currently held by each chain.'''
        models = self.history.models[:self.iiter+1]
        nfade = 20

        for ichain in range(self.chains.nchains):
            iiters = self.chains.indices(ichain)
            fx = self.problem.extract(models[iiters, :], self.ixpar)
            fy = self.problem.extract(models[iiters, :], self.iypar)

            # t1 is 0 for models older than the last ``nfade`` iterations and
            # grows towards 1 for newer ones; it modulates the marker size.
            t1 = num.maximum(0.0, iiters - (models.shape[0] - nfade)) / nfade
            factors = num.sqrt(1.0 - t1) * (1.0 + 15. * t1**2)

            msizes = msize * factors

            paths = self.axes.scatter(
                self.xpar.scaled(fx),
                self.ypar.scaled(fy),
                color=self.bcolors[ichain],
                s=msizes, alpha=0.5, edgecolors='none')

            self._volatile.append(paths)

    def _draw_phase_samples(self, msize, nsamples=1000):
        '''
        Scatter ``nsamples`` candidates from the active sampler phase.

        :returns: the sampler phase active at ``self.iiter``.
        '''
        _, phase, iiter_phase = self.optimiser.get_sampler_phase(self.iiter)

        models_prob = num.zeros((nsamples, self.problem.nparameters))
        for isample in range(nsamples):
            models_prob[isample, :] = phase.get_sample(
                self.problem, iiter_phase, self.chains)

        fx = self.problem.extract(models_prob, self.ixpar)
        fy = self.problem.extract(models_prob, self.iypar)

        collection = self.axes.scatter(
            self.xpar.scaled(fx),
            self.ypar.scaled(fy),
            color='green',
            s=msize * 0.15, alpha=0.2, edgecolors='none')

        self._volatile.append(collection)
        return phase

    def _draw_label(self, phase):
        '''Annotate the frame with iteration number and sampler phase name.'''
        artist = self.axes.annotate(
            '%i (%s)' % (self.iiter+1, phase.__class__.__name__),
            xy=(0., 1.),
            xycoords='axes fraction',
            xytext=(self.fontsize/2., -self.fontsize/2.),
            textcoords='offset points',
            ha='left',
            va='top',
            fontsize=self.fontsize,
            fontstyle='normal')

        self._volatile.append(artist)

    def finish(self):
        '''Finalise the movie file (if any) and leave interactive mode.'''
        if self.writer:
            self.writer.finish()

        if self.show:
            # NOTE(review): interactive mode is still on here, in which case
            # plt.show() does not block; calling plt.ioff() first would keep
            # the final figure open - confirm which is intended.
            plt.show()
            plt.ioff()

    def render(self):
        '''Draw one frame per model in the history, then finish up.'''
        self.start()

        while self.iiter < self.history.nmodels:
            logger.info(
                'Rendering frame %i/%i.', self.iiter+1, self.history.nmodels)
            self.draw_frame()
            self.iiter += 1

        self.finish()
def rolling_window(a, window):
    '''
    Return a strided view of every length-``window`` run along the last axis.

    For an input of shape ``(..., n)`` the result has shape
    ``(..., n - window + 1, window)``. It is built with
    ``numpy.lib.stride_tricks.as_strided``, so no data is copied and
    overlapping windows share memory with ``a``.
    '''
    n_windows = a.shape[-1] - window + 1
    view_shape = (*a.shape[:-1], n_windows, window)
    # Moving to the next window and moving to the next element inside a
    # window both advance by one element of the last axis.
    view_strides = (*a.strides, a.strides[-1])
    return num.lib.stride_tricks.as_strided(
        a, shape=view_shape, strides=view_strides)
class HighScoreAcceptancePlot(PlotConfig):
    '''Model acceptance plot '''
    name = 'acceptance'
    size_cm = Tuple.T(2, Float.T(), default=(21., 14.9))

    def make(self, environ):
        '''Register the acceptance figures with the plot collection manager.'''
        cm = environ.get_plot_collection_manager()
        cm.create_group_mpl(
            self,
            self.draw_figures(environ),
            title=u'Acceptance',
            section='optimiser',
            description=u'''
Model acceptance and accepted model popularities.

The plots in this section can be used to investigate performance and
characteristics of the optimisation algorithm.
''',
            feather_icon='check')

    def draw_figures(self, environ):
        '''
        Generate ``(PlotItem, figure)`` pairs for the acceptance plots.

        Yields the moving-window acceptance/popularity figure, followed by
        the per-chain acceptance image. Yields nothing, and only logs a
        warning, when the history holds fewer than ``nwindow`` models.
        '''
        nwindow = 200
        optimiser = environ.get_optimiser()
        problem = environ.get_problem()
        history = environ.get_history()
        chains = optimiser.chains(problem, history)
        chains.load()

        acceptance = chains.acceptance_history

        nmodels_rate = history.nmodels - (nwindow - 1)
        if nmodels_rate < 1:
            logger.warning(
                'Cannot create plot acceptance: insufficient number of tested '
                'models.')

            return

        yield self._draw_acceptance_rate(
            optimiser, history, acceptance, nwindow)

        yield self._draw_acceptance_img(history, acceptance)

    def _draw_acceptance_rate(
            self, optimiser, history, acceptance, nwindow,
            show_raw_acceptance_rates=False):
        '''
        Build the moving-window acceptance rate and popularity figure.

        :returns: ``(PlotItem, figure)``.
        '''
        acceptance_n = num.sum(acceptance, axis=0)

        acceptance_any = num.minimum(acceptance_n, 1)

        acceptance_any_rate = trace.moving_sum(
            acceptance_any, nwindow, mode='valid') / float(nwindow)

        acceptance_p = acceptance_n / float(history.nchains)

        # Windows without any accepted model have a zero denominator; the
        # resulting non-finite values are kept, but the RuntimeWarning is
        # silenced.
        with num.errstate(divide='ignore', invalid='ignore'):
            popularity = trace.moving_sum(
                acceptance_p, nwindow, mode='valid') \
                / float(nwindow) / acceptance_any_rate

        mpl_init(fontsize=self.font_size)
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(fig, w=7., h=5., units=self.font_size)

        axes = fig.add_subplot(1, 1, 1)
        labelpos(axes, 2.5, 2.0)

        imodels = num.arange(history.nmodels)

        # 'valid' moving sums start once the first full window is available.
        imodels_rate = imodels[nwindow-1:]

        axes.plot(
            acceptance_n/history.nchains * 100.,
            '.',
            ms=2.0,
            color=mpl_color('skyblue2'),
            label='Popularity of Accepted Models',
            alpha=0.3)

        if show_raw_acceptance_rates:
            nmodels_rate = history.nmodels - (nwindow - 1)
            acceptance_rate = num.zeros((history.nchains, nmodels_rate))
            for ichain in range(history.nchains):
                acceptance_rate[ichain, :] = trace.moving_sum(
                    acceptance[ichain, :], nwindow, mode='valid') \
                    / float(nwindow)

                axes.plot(imodels_rate, acceptance_rate[ichain, :]*100.,
                          color=mpl_color('scarletred2'), alpha=0.2)

        axes.plot(
            imodels_rate,
            popularity * 100.,
            color=mpl_color('skyblue2'),
            label='Popularity (moving average)')
        axes.plot(
            imodels_rate,
            acceptance_any_rate*100.,
            color='black',
            label='Acceptance Rate (any chain)')

        axes.legend()

        axes.set_xlabel('Iteration')
        axes.set_ylabel('Acceptance Rate, Model Popularity')

        axes.set_ylim(0., 100.)
        axes.set_xlim(0., history.nmodels - 1)
        axes.grid(alpha=.2)
        axes.yaxis.set_major_formatter(FuncFormatter(lambda v, p: '%d%%' % v))

        # Alternate background shading marks the extent of each sampler phase.
        iiter = 0
        bgcolors = [mpl_color('aluminium1'), mpl_color('aluminium2')]
        for iphase, phase in enumerate(optimiser.sampler_phases):
            axes.axvspan(
                iiter, iiter+phase.niterations,
                color=bgcolors[iphase % len(bgcolors)])

            iiter += phase.niterations

        return (
            PlotItem(
                name='acceptance',
                description=u'''
Acceptance rate (black line) within a moving window of %d iterations.

A model is considered accepted, if it is accepted in at least one chain. The
popularity of accepted models is shown as blue dots. Popularity is defined as
the percentage of chains accepting the model (100%% meaning acceptance in all
chains). A moving average of the popularities is shown as blue line (same
averaging interval as for the acceptance rate). Different background colors
represent different sampler phases.
''' % nwindow),
            fig)

    def _draw_acceptance_img(self, history, acceptance):
        '''
        Build the per-chain acceptance image figure.

        :returns: ``(PlotItem, figure)``.
        '''
        mpl_init(fontsize=self.font_size)
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(fig, w=7., h=5., units=self.font_size)

        axes = fig.add_subplot(1, 1, 1)
        labelpos(axes, 2.5, 2.0)

        imodels = num.arange(history.nmodels)

        # NOTE(review): the averaging window is derived from the figure
        # height (size_inch[1]) although iterations run along the x axis;
        # size_inch[0] may have been intended - confirm.
        nwindow2 = max(1, int(history.nmodels / (self.size_inch[1] * 100)))
        nmodels_rate2 = history.nmodels - (nwindow2 - 1)
        acceptance_rate2 = num.zeros((history.nchains, nmodels_rate2))
        for ichain in range(history.nchains):
            acceptance_rate2[ichain, :] = trace.moving_sum(
                acceptance[ichain, :], nwindow2, mode='valid') \
                / float(nwindow2)

        imodels_rate2 = imodels[nwindow2-1:]

        # The 0.01 offset keeps log() finite for windows without acceptance.
        _pcolormesh_same_dim(
            axes,
            imodels_rate2,
            num.arange(history.nchains),
            num.log(0.01+acceptance_rate2),
            cmap='GnBu')

        if history.sampler_contexts is not None:
            axes.plot(
                imodels,
                history.sampler_contexts[:, 1],
                '.',
                ms=2.0,
                color='black',
                label='Breeding Chain',
                alpha=0.3)

        axes.set_xlabel('Iteration')
        axes.set_ylabel('Bootstrap Chain')
        axes.set_xlim(0, history.nmodels - 1)
        axes.set_ylim(0, history.nchains - 1)

        axes.xaxis.grid(alpha=.4)

        return (
            PlotItem(
                name='acceptance_img',
                description=u'''
Model acceptance per bootstrap chain averaged over %d models (background color,
low to high acceptance as light to dark colors).

Black dots mark the base chains used when sampling new models (directed sampler
phases only).
''' % nwindow2),
            fig)
# Public API of this module.
__all__ = ['HighScoreOptimiserPlot', 'HighScoreAcceptancePlot']