Coverage for /usr/local/lib/python3.11/dist-packages/grond/optimisers/highscore/plot.py: 18%
222 statements
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
« prev ^ index » next coverage.py v6.5.0, created at 2023-10-26 16:25 +0000
1from __future__ import print_function
2import logging
3import numpy as num
5from matplotlib import pyplot as plt
6from matplotlib.ticker import FuncFormatter
8from pyrocko.plot import mpl_init, mpl_margins, mpl_color
9from pyrocko.guts import Tuple, Float
10from pyrocko import trace
12from grond.plot.config import PlotConfig
13from grond.plot.collection import PlotItem
# Module-level logger for the high-score optimiser plots.
# NOTE(review): the logger name uses 'optimiser' while the package path is
# 'grond/optimisers/...'; confirm the singular form is intentional.
logger = logging.getLogger(
    'grond.optimiser.highscore.plot')

# Namespace prefix for the guts-based configuration classes in this module
# (presumably read by pyrocko.guts when (de)serialising them).
guts_prefix = 'grond'
def _pcolormesh_same_dim(ax, x, y, v, **kwargs):
    '''
    Draw a pseudocolor mesh from coordinate and value arrays of equal shape.

    First tries ``shading='nearest'``, which accepts *x*, *y* and *v* with
    identical dimensions. If that call raises ``TypeError`` (matplotlib
    versions < 3.3), it retries with matplotlib's default shading and the
    last row and column of *v* dropped, so that *v* is one smaller than the
    coordinates along each axis.

    :param ax: matplotlib axes to draw into
    :param x, y: cell coordinates, same dimension as *v*
    :param v: 2D value array
    :param kwargs: passed through to ``ax.pcolormesh``
    :returns: the artist returned by ``ax.pcolormesh``
    '''
    try:
        mesh = ax.pcolormesh(x, y, v, shading='nearest', **kwargs)
    except TypeError:
        # matplotlib versions < 3.3
        v_cells = v[:-1, :-1]
        mesh = ax.pcolormesh(x, y, v_cells, **kwargs)

    return mesh
class HighScoreOptimiserPlot(object):
    '''
    Frame-by-frame animation of a high-score optimiser run.

    The tested models in *history* are replayed one iteration at a time and
    projected onto the plane of the two parameters *xpar_name* (x axis) and
    *ypar_name* (y axis). Every frame shows:

    * all models tested so far (small black dots),
    * the models referenced by each chain at that iteration (one colour per
      chain, the most recent ones drawn larger),
    * 1000 models drawn from the active sampler phase (green dots),
    * a label with the iteration number and the sampler phase class name.

    If *movie_filename* is set, every frame is written to that file through
    matplotlib's ``FFMpegWriter``. Set :py:attr:`show` to ``True`` before
    calling :py:meth:`render` to additionally display the frames
    interactively.
    '''

    def __init__(
            self, optimiser, problem, history, xpar_name, ypar_name,
            movie_filename):

        self.optimiser = optimiser
        self.problem = problem
        self.chains = optimiser.chains(problem, history)
        self.history = history
        self.xpar_name = xpar_name
        self.ypar_name = ypar_name
        self.fontsize = 10.
        self.movie_filename = movie_filename
        self.show = False
        # Index of the iteration drawn by the next call to draw_frame().
        self.iiter = 0
        self.iiter_last_draw = 0
        # Artists that belong to the current frame only; they are removed
        # again before the next frame is drawn.
        self._volatile = []
        # Indices of completely filled blocks of tested models. These never
        # change and therefore stay on the axes permanently.
        self._blocks_complete = set()

    def start(self):
        '''
        Create figure and axes and, if requested, set up the movie writer.
        '''
        nfx = 1
        nfy = 1

        problem = self.problem

        ixpar = problem.name_to_index(self.xpar_name)
        iypar = problem.name_to_index(self.ypar_name)

        mpl_init(fontsize=self.fontsize)
        fig = plt.figure(figsize=(9.6, 5.4))
        labelpos = mpl_margins(fig, nw=nfx, nh=nfy, w=7., h=5., wspace=7.,
                               hspace=2., units=self.fontsize)

        xpar = problem.parameters[ixpar]
        ypar = problem.parameters[iypar]

        # An equal aspect ratio is only meaningful if both axes share a unit.
        if xpar.unit == ypar.unit:
            axes = fig.add_subplot(nfy, nfx, 1, aspect=1.0)
        else:
            axes = fig.add_subplot(nfy, nfx, 1)

        labelpos(axes, 2.5, 2.0)

        axes.set_xlabel(xpar.get_label())
        axes.set_ylabel(ypar.get_label())

        axes.get_xaxis().set_major_locator(plt.MaxNLocator(4))
        axes.get_yaxis().set_major_locator(plt.MaxNLocator(4))

        # Faint cross-hair at the reference model.
        xref = problem.get_reference_model()
        axes.axvline(xpar.scaled(xref[ixpar]), color='black', alpha=0.3)
        axes.axhline(ypar.scaled(xref[iypar]), color='black', alpha=0.3)

        self.fig = fig
        self.xpar = xpar
        self.ypar = ypar
        self.axes = axes
        self.ixpar = ixpar
        self.iypar = iypar

        # Random colours, nbootstrap + 1 of them (presumably one per chain);
        # the first one is forced to black.
        from matplotlib import colors
        n = self.optimiser.nbootstrap + 1
        hsv = num.vstack((
            num.random.uniform(0., 1., n),
            num.random.uniform(0.5, 0.9, n),
            num.repeat(0.7, n))).T

        self.bcolors = colors.hsv_to_rgb(hsv[num.newaxis, :, :])[0, :, :]
        self.bcolors[0, :] = [0., 0., 0.]

        bounds = self.problem.get_combined_bounds()

        from grond import plot
        self.xlim = plot.fixlim(*xpar.scaled(bounds[ixpar]))
        self.ylim = plot.fixlim(*ypar.scaled(bounds[iypar]))

        self.set_limits()

        from matplotlib.colors import LinearSegmentedColormap

        self.cmap = LinearSegmentedColormap.from_list('probability', [
            (1.0, 1.0, 1.0),
            (0.5, 0.9, 0.6)])

        self.writer = None
        if self.movie_filename:
            from matplotlib.animation import FFMpegWriter

            metadata = dict(title=problem.name, artist='Grond')

            self.writer = FFMpegWriter(
                fps=30,
                metadata=metadata,
                codec='libx264',
                bitrate=200000,
                extra_args=[
                    '-pix_fmt', 'yuv420p',
                    '-profile:v', 'baseline',
                    '-level', '3',
                    '-an'])

            self.writer.setup(self.fig, self.movie_filename, dpi=200)

        if self.show:
            plt.ion()
            plt.show()

    def set_limits(self):
        '''Fix the axes limits to the scaled parameter bounds.'''
        self.axes.autoscale(False)
        self.axes.set_xlim(*self.xlim)
        self.axes.set_ylim(*self.ylim)

    def draw_frame(self):
        '''
        Draw the frame for iteration ``self.iiter``.

        Volatile artists of the previous frame are removed first. When a
        movie writer is active, the completed frame - including the
        iteration label - is grabbed into the movie.
        '''
        self.chains.goto(self.iiter+1)
        msize = 15.
        nblock = 100  # tested models per scatter collection

        for artist in self._volatile:
            artist.remove()

        self._volatile[:] = []

        nmodels = self.iiter + 1
        nblocks = self.iiter // nblock + 1

        models = self.history.models[:nmodels]

        # Tested models are drawn in blocks. A completed block never changes
        # and is kept on the axes; only the trailing, partially filled block
        # is redrawn on every frame.
        for iblock in range(nblocks):
            if iblock in self._blocks_complete:
                continue

            models_add = self.history.models[
                iblock*nblock:min((iblock+1)*nblock, nmodels)]

            fx = self.problem.extract(models_add, self.ixpar)
            fy = self.problem.extract(models_add, self.iypar)
            collection = self.axes.scatter(
                self.xpar.scaled(fx),
                self.ypar.scaled(fy),
                color='black',
                s=msize * 0.15, alpha=0.2, edgecolors='none')

            if models_add.shape[0] != nblock:
                self._volatile.append(collection)
            else:
                self._blocks_complete.add(iblock)

        # Chain members; the nfade most recent iterations get larger markers.
        nfade = 20
        for ichain in range(self.chains.nchains):

            iiters = self.chains.indices(ichain)
            fx = self.problem.extract(models[iiters, :], self.ixpar)
            fy = self.problem.extract(models[iiters, :], self.iypar)

            t1 = num.maximum(0.0, iiters - (models.shape[0] - nfade)) / nfade
            factors = num.sqrt(1.0 - t1) * (1.0 + 15. * t1**2)

            msizes = msize * factors

            paths = self.axes.scatter(
                self.xpar.scaled(fx),
                self.ypar.scaled(fy),
                color=self.bcolors[ichain],
                s=msizes, alpha=0.5, edgecolors='none')

            self._volatile.append(paths)

        # Sample the distribution of the active sampler phase to visualise
        # where the next candidate models would be drawn from.
        _, phase, iiter_phase = self.optimiser.get_sampler_phase(self.iiter)

        nsamples = 1000
        models_prob = num.zeros((nsamples, self.problem.nparameters))
        for isample in range(nsamples):
            models_prob[isample, :] = phase.get_sample(
                self.problem, iiter_phase, self.chains)

        fx = self.problem.extract(models_prob, self.ixpar)
        fy = self.problem.extract(models_prob, self.iypar)

        collection = self.axes.scatter(
            self.xpar.scaled(fx),
            self.ypar.scaled(fy),
            color='green',
            s=msize * 0.15, alpha=0.2, edgecolors='none')

        self._volatile.append(collection)

        artist = self.axes.annotate(
            '%i (%s)' % (self.iiter+1, phase.__class__.__name__),
            xy=(0., 1.),
            xycoords='axes fraction',
            xytext=(self.fontsize/2., -self.fontsize/2.),
            textcoords='offset points',
            ha='left',
            va='top',
            fontsize=self.fontsize,
            fontstyle='normal')

        self._volatile.append(artist)

        # Grab only after the label has been added: the label is volatile
        # and is removed at the start of the next frame, so grabbing earlier
        # would leave it out of every movie frame.
        if self.writer:
            self.writer.grab_frame()

        if self.show:
            plt.draw()

        self.iiter_last_draw = self.iiter + 1

    def finish(self):
        '''Finalise the movie file and, in interactive mode, show the plot.'''
        if self.writer:
            self.writer.finish()

        if self.show:
            plt.show()
            plt.ioff()

    def render(self):
        '''
        Draw one frame for every model in the history.

        :py:meth:`finish` is called even if drawing fails part-way, so that
        an active movie writer is always closed.
        '''
        self.start()

        try:
            while self.iiter < self.history.nmodels:
                logger.info(
                    'Rendering frame %i/%i.',
                    self.iiter+1, self.history.nmodels)
                self.draw_frame()
                self.iiter += 1
        finally:
            self.finish()
def rolling_window(a, window):
    '''
    Get a strided view of all windows of length *window* along the last
    axis of *a*.

    The result has shape ``a.shape[:-1] + (a.shape[-1] - window + 1,
    window)``. It is a view created with
    ``numpy.lib.stride_tricks.as_strided``: no data are copied, and
    overlapping windows refer to the same elements of *a*.

    :param a: numpy array with at least one dimension
    :param window: window length, ``1 <= window <= a.shape[-1]``
    :returns: strided view of *a*
    :raises ValueError: if *window* is outside the valid range
    '''
    n = a.shape[-1]
    # Without this check, window == 0 silently yields a meaningless empty
    # view, and window > n fails inside numpy with an opaque message.
    if not 1 <= window <= n:
        raise ValueError(
            'window must be between 1 and %i, got %r' % (n, window))

    shape = a.shape[:-1] + (n - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return num.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
class HighScoreAcceptancePlot(PlotConfig):
    '''Model acceptance plot '''
    name = 'acceptance'
    size_cm = Tuple.T(2, Float.T(), default=(21., 14.9))

    def make(self, environ):
        '''Add the "Acceptance" figure group to the plot collection.'''
        cm = environ.get_plot_collection_manager()
        cm.create_group_mpl(
            self,
            self.draw_figures(environ),
            title=u'Acceptance',
            section='optimiser',
            description=u'''
Model acceptance and accepted model popularities.

The plots in this section can be used to investigate performance and
characteristics of the optimisation algorithm.
''',
            feather_icon='check')

    def draw_figures(self, environ):
        '''
        Create the acceptance figures of an optimiser run.

        Generator yielding ``(PlotItem, figure)`` tuples:

        * ``acceptance`` - acceptance rate (any chain) and popularity of
          the accepted models, averaged over a moving window of ``nwindow``
          iterations, with the sampler phases shaded in the background.
        * ``acceptance_img`` - per-chain moving-average acceptance rate as
          a colour image, overlaid with the breeding chain of each
          iteration if ``history.sampler_contexts`` is available.

        If the history holds fewer models than the averaging window, a
        warning is logged and nothing is yielded.
        '''
        nwindow = 200
        show_raw_acceptance_rates = False
        optimiser = environ.get_optimiser()
        problem = environ.get_problem()
        history = environ.get_history()
        chains = optimiser.chains(problem, history)
        chains.load()

        # NOTE(review): assumes acceptance_history is indexed as
        # [ichain, imodel] with 0/1 entries - inferred from its use below.
        acceptance = chains.acceptance_history
        nchains = history.nchains

        nmodels_rate = history.nmodels - (nwindow - 1)
        if nmodels_rate < 1:
            logger.warning(
                'Cannot create plot acceptance: insufficient number of tested '
                'models.')

            return

        def moving_mean(values, n):
            # Mean over a sliding window of n samples (fully covered part).
            return trace.moving_sum(values, n, mode='valid') / float(n)

        def chain_acceptance_rates(n):
            # Moving-mean acceptance of every chain over windows of n models.
            rates = num.zeros((nchains, history.nmodels - (n - 1)))
            for ichain in range(nchains):
                rates[ichain, :] = moving_mean(acceptance[ichain, :], n)

            return rates

        # Number of chains accepting each model and whether any chain did.
        acceptance_n = num.sum(acceptance, axis=0)
        acceptance_any = num.minimum(acceptance_n, 1)
        acceptance_any_rate = moving_mean(acceptance_any, nwindow)

        # Fraction of chains accepting each model.
        acceptance_p = acceptance_n / float(nchains)

        # Mean popularity of the accepted models within each window. In
        # windows without any accepted model it is undefined: keep NaN there
        # (a gap in the curve) instead of dividing 0 by 0, which would emit
        # RuntimeWarnings.
        popularity_sum = moving_mean(acceptance_p, nwindow)
        popularity = num.full_like(acceptance_any_rate, num.nan)
        has_accepted = acceptance_any_rate > 0.
        popularity[has_accepted] = \
            popularity_sum[has_accepted] / acceptance_any_rate[has_accepted]

        mpl_init(fontsize=self.font_size)
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(fig, w=7., h=5., units=self.font_size)

        axes = fig.add_subplot(1, 1, 1)
        labelpos(axes, 2.5, 2.0)

        imodels = num.arange(history.nmodels)

        # Each window value is plotted at the last iteration it covers.
        imodels_rate = imodels[nwindow-1:]

        axes.plot(
            acceptance_p * 100.,
            '.',
            ms=2.0,
            color=mpl_color('skyblue2'),
            label='Popularity of Accepted Models',
            alpha=0.3)

        if show_raw_acceptance_rates:
            acceptance_rate = chain_acceptance_rates(nwindow)
            for ichain in range(nchains):
                axes.plot(imodels_rate, acceptance_rate[ichain, :]*100.,
                          color=mpl_color('scarletred2'), alpha=0.2)

        axes.plot(
            imodels_rate,
            popularity * 100.,
            color=mpl_color('skyblue2'),
            label='Popularity (moving average)')
        axes.plot(
            imodels_rate,
            acceptance_any_rate*100.,
            color='black',
            label='Acceptance Rate (any chain)')

        axes.legend()

        axes.set_xlabel('Iteration')
        axes.set_ylabel('Acceptance Rate, Model Popularity')

        axes.set_ylim(0., 100.)
        axes.set_xlim(0., history.nmodels - 1)
        axes.grid(alpha=.2)
        axes.yaxis.set_major_formatter(FuncFormatter(lambda v, p: '%d%%' % v))

        # Shade the iteration range of each sampler phase, alternating
        # between two background colours.
        iiter = 0
        bgcolors = [mpl_color('aluminium1'), mpl_color('aluminium2')]
        for iphase, phase in enumerate(optimiser.sampler_phases):
            axes.axvspan(
                iiter, iiter+phase.niterations,
                color=bgcolors[iphase % len(bgcolors)])

            iiter += phase.niterations

        yield (
            PlotItem(
                name='acceptance',
                description=u'''
Acceptance rate (black line) within a moving window of %d iterations.

A model is considered accepted, if it is accepted in at least one chain. The
popularity of accepted models is shown as blue dots. Popularity is defined as
the percentage of chains accepting the model (100%% meaning acceptance in all
chains). A moving average of the popularities is shown as blue line (same
averaging interval as for the acceptance rate). Different background colors
represent different sampler phases.
''' % nwindow),
            fig)

        mpl_init(fontsize=self.font_size)
        fig = plt.figure(figsize=self.size_inch)
        labelpos = mpl_margins(fig, w=7., h=5., units=self.font_size)

        axes = fig.add_subplot(1, 1, 1)
        labelpos(axes, 2.5, 2.0)

        # Averaging window for the image, scaled with the number of models
        # (at least 1).
        # NOTE(review): this uses the figure height (size_inch[1]) although
        # the window runs along the x axis - confirm this is intended.
        nwindow2 = max(1, int(history.nmodels / (self.size_inch[1] * 100)))
        acceptance_rate2 = chain_acceptance_rates(nwindow2)

        imodels_rate2 = imodels[nwindow2-1:]

        _pcolormesh_same_dim(
            axes,
            imodels_rate2,
            num.arange(nchains),
            num.log(0.01+acceptance_rate2),
            cmap='GnBu')

        if history.sampler_contexts is not None:
            axes.plot(
                imodels,
                history.sampler_contexts[:, 1],
                '.',
                ms=2.0,
                color='black',
                label='Breeding Chain',
                alpha=0.3)

        axes.set_xlabel('Iteration')
        axes.set_ylabel('Bootstrap Chain')
        axes.set_xlim(0, history.nmodels - 1)
        # With 'nearest' shading each chain row is centred on its index, so
        # include half a row below the first and above the last chain. This
        # also avoids singular limits when there is only one chain.
        axes.set_ylim(-0.5, nchains - 0.5)

        axes.xaxis.grid(alpha=.4)

        yield (
            PlotItem(
                name='acceptance_img',
                description=u'''
Model acceptance per bootstrap chain averaged over %d models (background color,
low to high acceptance as light to dark colors).

Black dots mark the base chains used when sampling new models (directed sampler
phases only).
''' % nwindow2),
            fig)
# Public API of this module: the optimiser animation and the acceptance
# plot configuration.
__all__ = [
    'HighScoreOptimiserPlot',
    'HighScoreAcceptancePlot',
]